Revert "Bug fix PR #423" (#426)

This reverts commit b498797bfef45414d1a50147e3f1097b7179e5a8.

Co-authored-by: Filippo Olivo <filippo@filippoolivo.com>
Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
Dario Coscia
2025-01-23 15:13:14 +01:00
committed by Nicola Demo
parent 6b122d8b2c
commit 9aed1a30b3

View File

@@ -256,9 +256,9 @@ class PinaDataModule(LightningDataModule):
sampler = PinaBatchSampler(self.val_dataset, self.batch_size, shuffle=False) sampler = PinaBatchSampler(self.val_dataset, self.batch_size, shuffle=False)
return DataLoader(self.val_dataset, sampler=sampler, return DataLoader(self.val_dataset, sampler=sampler,
collate_fn=collate) collate_fn=collate)
dataloader = DummyDataloader(self.val_dataset, self.trainer.strategy.root_device) dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device)
dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0) dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0)
return dataloader self.transfer_batch_to_device = self.dummy_transfer_to_device
def train_dataloader(self): def train_dataloader(self):
""" """
@@ -278,6 +278,7 @@ class PinaDataModule(LightningDataModule):
collate_fn=collate) collate_fn=collate)
dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device) dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device)
dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0) dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0)
self.transfer_batch_to_device = self.dummy_transfer_to_device
return dataloader return dataloader
def test_dataloader(self): def test_dataloader(self):
@@ -292,18 +293,18 @@ class PinaDataModule(LightningDataModule):
""" """
raise NotImplementedError("Predict dataloader not implemented") raise NotImplementedError("Predict dataloader not implemented")
def dummy_transfer_to_device(self, batch, device, dataloader_idx):
return batch
def transfer_batch_to_device(self, batch, device, dataloader_idx): def transfer_batch_to_device(self, batch, device, dataloader_idx):
""" """
Transfer the batch to the device. This method is called in the Transfer the batch to the device. This method is called in the
training loop and is used to transfer the batch to the device. training loop and is used to transfer the batch to the device.
""" """
if isinstance(batch, list):
return batch
batch = [ batch = [
(k, super(LightningDataModule, self).transfer_batch_to_device(v, (k, super(LightningDataModule, self).transfer_batch_to_device(v,
device, device,
dataloader_idx)) dataloader_idx))
for k, v in batch.items() for k, v in batch.items()
] ]
return batch return batch