diff --git a/pina/data/data_module.py b/pina/data/data_module.py index da03266..4b529fe 100644 --- a/pina/data/data_module.py +++ b/pina/data/data_module.py @@ -256,9 +256,9 @@ class PinaDataModule(LightningDataModule): sampler = PinaBatchSampler(self.val_dataset, self.batch_size, shuffle=False) return DataLoader(self.val_dataset, sampler=sampler, collate_fn=collate) - dataloader = DummyDataloader(self.val_dataset, self.trainer.strategy.root_device) + dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device) dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0) - return dataloader + self.transfer_batch_to_device = self.dummy_transfer_to_device def train_dataloader(self): """ @@ -278,6 +278,7 @@ class PinaDataModule(LightningDataModule): collate_fn=collate) dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device) dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0) + self.transfer_batch_to_device = self.dummy_transfer_to_device return dataloader def test_dataloader(self): @@ -292,18 +293,18 @@ class PinaDataModule(LightningDataModule): """ raise NotImplementedError("Predict dataloader not implemented") + def dummy_transfer_to_device(self, batch, device, dataloader_idx): + return batch + def transfer_batch_to_device(self, batch, device, dataloader_idx): """ Transfer the batch to the device. This method is called in the training loop and is used to transfer the batch to the device. """ - if isinstance(batch, list): - return batch batch = [ (k, super(LightningDataModule, self).transfer_batch_to_device(v, device, dataloader_idx)) for k, v in batch.items() ] - return batch