From 6b122d8b2cf32d0a7ef18af9c93b6724c62d3afb Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Thu, 23 Jan 2025 15:04:15 +0100 Subject: [PATCH] Bug fix PR #423 --- pina/data/data_module.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/pina/data/data_module.py b/pina/data/data_module.py index 4b529fe..da03266 100644 --- a/pina/data/data_module.py +++ b/pina/data/data_module.py @@ -256,9 +256,9 @@ class PinaDataModule(LightningDataModule): sampler = PinaBatchSampler(self.val_dataset, self.batch_size, shuffle=False) return DataLoader(self.val_dataset, sampler=sampler, collate_fn=collate) - dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device) + dataloader = DummyDataloader(self.val_dataset, self.trainer.strategy.root_device) dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0) - self.transfer_batch_to_device = self.dummy_transfer_to_device + return dataloader def train_dataloader(self): """ @@ -278,7 +278,6 @@ class PinaDataModule(LightningDataModule): collate_fn=collate) dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device) dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0) - self.transfer_batch_to_device = self.dummy_transfer_to_device return dataloader def test_dataloader(self): @@ -293,18 +292,18 @@ class PinaDataModule(LightningDataModule): """ raise NotImplementedError("Predict dataloader not implemented") - def dummy_transfer_to_device(self, batch, device, dataloader_idx): - return batch - def transfer_batch_to_device(self, batch, device, dataloader_idx): """ Transfer the batch to the device. This method is called in the training loop and is used to transfer the batch to the device. """ + if isinstance(batch, list): + return batch batch = [ (k, super(LightningDataModule, self).transfer_batch_to_device(v, device, dataloader_idx)) for k, v in batch.items() ] + return batch