Skip to content

Commit

Permalink
Add possibility to use custom padding (address issue #458) (#489)
Browse files Browse the repository at this point in the history
* add possibility to use custom padding

* Add padding_dict to actual function

* Simplify arguments and fix docs

* typo
  • Loading branch information
IgnacioJPickering authored Jun 17, 2020
1 parent c251739 commit 1aa77d8
Showing 1 changed file with 18 additions and 8 deletions.
26 changes: 18 additions & 8 deletions torchani/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@
- `shuffle`
- `cache` cache the result of previous transformations.
- `collate` pad the dataset, convert it to tensor, and stack them
together to get a batch.
together to get a batch. `collate` pads with the default dictionary
``{'species': -1, 'coordinates': 0.0, 'forces': 0.0, 'energies': 0.0}``,
but a custom padding dictionary can be passed as an optional parameter
to override these defaults.
- `pin_memory` copy the tensors to pinned memory so that later transfer
to CUDA can be faster.
Expand Down Expand Up @@ -94,8 +98,8 @@

verbose = True


PROPERTIES = ('energies',)

PADDING = {
'species': -1,
'coordinates': 0.0,
Expand All @@ -104,8 +108,11 @@
}


def collate_fn(samples):
return utils.stack_with_padding(samples, PADDING)
def collate_fn(samples, padding=None):
    """Stack a list of samples into a single padded batch.

    When ``padding`` is None, the module-level ``PADDING`` defaults are
    used; otherwise the caller-supplied padding dictionary takes effect.
    """
    effective_padding = padding if padding is not None else PADDING
    return utils.stack_with_padding(samples, effective_padding)


class IterableAdapter:
Expand Down Expand Up @@ -241,19 +248,22 @@ def cache(reenterable_iterable):
return ret

@staticmethod
def collate(reenterable_iterable, batch_size):
def reenterable_iterable_factory():
def collate(reenterable_iterable, batch_size, padding=None):
def reenterable_iterable_factory(padding=None):
batch = []
i = 0
for d in reenterable_iterable:
batch.append(d)
i += 1
if i == batch_size:
i = 0
yield collate_fn(batch)
yield collate_fn(batch, padding)
batch = []
if len(batch) > 0:
yield collate_fn(batch)
yield collate_fn(batch, padding)

reenterable_iterable_factory = functools.partial(reenterable_iterable_factory,
padding)
try:
length = (len(reenterable_iterable) + batch_size - 1) // batch_size
return IterableAdapterWithLength(reenterable_iterable_factory, length)
Expand Down

0 comments on commit 1aa77d8

Please sign in to comment.