fix tests and modules

This commit is contained in:
FilippoOlivo
2025-11-14 16:52:10 +01:00
parent 8440a672a7
commit 43163fdf74
5 changed files with 47 additions and 33 deletions

View File

@@ -27,8 +27,7 @@ class PinaDataModule(LightningDataModule):
val_size=0.1,
batch_size=None,
shuffle=True,
common_batch_size=True,
separate_conditions=False,
batching_mode="common_batch_size",
automatic_batching=None,
num_workers=0,
pin_memory=False,
@@ -84,8 +83,7 @@ class PinaDataModule(LightningDataModule):
# Store fixed attributes
self.batch_size = batch_size
self.shuffle = shuffle
self.common_batch_size = common_batch_size
self.separate_conditions = separate_conditions
self.batching_mode = batching_mode
self.automatic_batching = automatic_batching
# If batch size is None, num_workers has no effect
@@ -280,8 +278,7 @@ class PinaDataModule(LightningDataModule):
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
common_batch_size=self.common_batch_size,
separate_conditions=self.separate_conditions,
batching_mode=self.batching_mode,
device=self.trainer.strategy.root_device,
)
if self.batch_size is None:
@@ -330,7 +327,7 @@ class PinaDataModule(LightningDataModule):
:rtype: list[tuple]
"""
return [(k, v) for k, v in batch.items()]
return list(batch.items())
def _transfer_batch_to_device(self, batch, device, dataloader_idx):
"""