This commit is contained in:
FilippoOlivo
2025-11-13 17:03:18 +01:00
parent 51a0399111
commit 0ee63686dd
3 changed files with 56 additions and 17 deletions

View File

@@ -275,14 +275,19 @@ class PinaDataModule(LightningDataModule):
),
module="lightning.pytorch.trainer.connectors.data_connector",
)
return PinaDataLoader(
dl = PinaDataLoader(
dataset,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
common_batch_size=self.common_batch_size,
separate_conditions=self.separate_conditions,
device=self.trainer.strategy.root_device,
)
if self.batch_size is None:
# Override the method to transfer the batch to the device
self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
return dl
def val_dataloader(self):
"""
@@ -325,7 +330,7 @@ class PinaDataModule(LightningDataModule):
:rtype: list[tuple]
"""
return batch
return [(k, v) for k, v in batch.items()]
def _transfer_batch_to_device(self, batch, device, dataloader_idx):
"""
@@ -383,9 +388,15 @@ class PinaDataModule(LightningDataModule):
to_return = {}
if hasattr(self, "train_dataset") and self.train_dataset is not None:
to_return["train"] = self.train_dataset.input
to_return["train"] = {
cond: data.input for cond, data in self.train_dataset.items()
}
if hasattr(self, "val_dataset") and self.val_dataset is not None:
to_return["val"] = self.val_dataset.input
to_return["val"] = {
cond: data.input for cond, data in self.val_dataset.items()
}
if hasattr(self, "test_dataset") and self.test_dataset is not None:
to_return["test"] = self.test_dataset.input
to_return["test"] = {
cond: data.input for cond, data in self.test_dataset.items()
}
return to_return