fix
This commit is contained in:
@@ -275,14 +275,19 @@ class PinaDataModule(LightningDataModule):
|
||||
),
|
||||
module="lightning.pytorch.trainer.connectors.data_connector",
|
||||
)
|
||||
return PinaDataLoader(
|
||||
dl = PinaDataLoader(
|
||||
dataset,
|
||||
batch_size=self.batch_size,
|
||||
shuffle=self.shuffle,
|
||||
num_workers=self.num_workers,
|
||||
common_batch_size=self.common_batch_size,
|
||||
separate_conditions=self.separate_conditions,
|
||||
device=self.trainer.strategy.root_device,
|
||||
)
|
||||
if self.batch_size is None:
|
||||
# Override the method to transfer the batch to the device
|
||||
self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
|
||||
return dl
|
||||
|
||||
def val_dataloader(self):
|
||||
"""
|
||||
@@ -325,7 +330,7 @@ class PinaDataModule(LightningDataModule):
|
||||
:rtype: list[tuple]
|
||||
"""
|
||||
|
||||
return batch
|
||||
return [(k, v) for k, v in batch.items()]
|
||||
|
||||
def _transfer_batch_to_device(self, batch, device, dataloader_idx):
|
||||
"""
|
||||
@@ -383,9 +388,15 @@ class PinaDataModule(LightningDataModule):
|
||||
|
||||
to_return = {}
|
||||
if hasattr(self, "train_dataset") and self.train_dataset is not None:
|
||||
to_return["train"] = self.train_dataset.input
|
||||
to_return["train"] = {
|
||||
cond: data.input for cond, data in self.train_dataset.items()
|
||||
}
|
||||
if hasattr(self, "val_dataset") and self.val_dataset is not None:
|
||||
to_return["val"] = self.val_dataset.input
|
||||
to_return["val"] = {
|
||||
cond: data.input for cond, data in self.val_dataset.items()
|
||||
}
|
||||
if hasattr(self, "test_dataset") and self.test_dataset is not None:
|
||||
to_return["test"] = self.test_dataset.input
|
||||
to_return["test"] = {
|
||||
cond: data.input for cond, data in self.test_dataset.items()
|
||||
}
|
||||
return to_return
|
||||
|
||||
Reference in New Issue
Block a user