fix tests and modules
This commit is contained in:
@@ -27,8 +27,7 @@ class PinaDataModule(LightningDataModule):
|
||||
val_size=0.1,
|
||||
batch_size=None,
|
||||
shuffle=True,
|
||||
common_batch_size=True,
|
||||
separate_conditions=False,
|
||||
batching_mode="common_batch_size",
|
||||
automatic_batching=None,
|
||||
num_workers=0,
|
||||
pin_memory=False,
|
||||
@@ -84,8 +83,7 @@ class PinaDataModule(LightningDataModule):
|
||||
# Store fixed attributes
|
||||
self.batch_size = batch_size
|
||||
self.shuffle = shuffle
|
||||
self.common_batch_size = common_batch_size
|
||||
self.separate_conditions = separate_conditions
|
||||
self.batching_mode = batching_mode
|
||||
self.automatic_batching = automatic_batching
|
||||
|
||||
# If batch size is None, num_workers has no effect
|
||||
@@ -280,8 +278,7 @@ class PinaDataModule(LightningDataModule):
|
||||
batch_size=self.batch_size,
|
||||
shuffle=self.shuffle,
|
||||
num_workers=self.num_workers,
|
||||
common_batch_size=self.common_batch_size,
|
||||
separate_conditions=self.separate_conditions,
|
||||
batching_mode=self.batching_mode,
|
||||
device=self.trainer.strategy.root_device,
|
||||
)
|
||||
if self.batch_size is None:
|
||||
@@ -330,7 +327,7 @@ class PinaDataModule(LightningDataModule):
|
||||
:rtype: list[tuple]
|
||||
"""
|
||||
|
||||
return [(k, v) for k, v in batch.items()]
|
||||
return list(batch.items())
|
||||
|
||||
def _transfer_batch_to_device(self, batch, device, dataloader_idx):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user