From 9aed1a30b34a42317202ef710cc528cbcc73de35 Mon Sep 17 00:00:00 2001 From: Dario Coscia <93731561+dario-coscia@users.noreply.github.com> Date: Thu, 23 Jan 2025 15:13:14 +0100 Subject: [PATCH] Revert "Bug fix PR #423" (#426) This reverts commit b498797bfef45414d1a50147e3f1097b7179e5a8. Co-authored-by: Filippo Olivo Co-authored-by: Dario Coscia --- pina/data/data_module.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pina/data/data_module.py b/pina/data/data_module.py index da03266..4b529fe 100644 --- a/pina/data/data_module.py +++ b/pina/data/data_module.py @@ -256,9 +256,9 @@ class PinaDataModule(LightningDataModule): sampler = PinaBatchSampler(self.val_dataset, self.batch_size, shuffle=False) return DataLoader(self.val_dataset, sampler=sampler, collate_fn=collate) - dataloader = DummyDataloader(self.val_dataset, self.trainer.strategy.root_device) + dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device) dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0) - return dataloader + self.transfer_batch_to_device = self.dummy_transfer_to_device def train_dataloader(self): """ @@ -278,6 +278,7 @@ class PinaDataModule(LightningDataModule): collate_fn=collate) dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device) dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0) + self.transfer_batch_to_device = self.dummy_transfer_to_device return dataloader def test_dataloader(self): @@ -292,18 +293,18 @@ class PinaDataModule(LightningDataModule): """ raise NotImplementedError("Predict dataloader not implemented") + def dummy_transfer_to_device(self, batch, device, dataloader_idx): + return batch + def transfer_batch_to_device(self, batch, device, dataloader_idx): """ Transfer the batch to the device. This method is called in the training loop and is used to transfer the batch to the device. """ - if isinstance(batch, list): - return batch batch = [ (k, super(LightningDataModule, self).transfer_batch_to_device(v, device, dataloader_idx)) for k, v in batch.items() ] - return batch