Revert "Bug fix PR #423" (#426)

This reverts commit b498797bfef45414d1a50147e3f1097b7179e5a8.

Co-authored-by: Filippo Olivo <filippo@filippoolivo.com>
Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
Dario Coscia
2025-01-23 15:13:14 +01:00
committed by Nicola Demo
parent 6b122d8b2c
commit 9aed1a30b3

View File

@@ -256,9 +256,9 @@ class PinaDataModule(LightningDataModule):
sampler = PinaBatchSampler(self.val_dataset, self.batch_size, shuffle=False)
return DataLoader(self.val_dataset, sampler=sampler,
collate_fn=collate)
dataloader = DummyDataloader(self.val_dataset, self.trainer.strategy.root_device)
dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device)
dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0)
return dataloader
self.transfer_batch_to_device = self.dummy_transfer_to_device
def train_dataloader(self):
"""
@@ -278,6 +278,7 @@ class PinaDataModule(LightningDataModule):
collate_fn=collate)
dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device)
dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0)
self.transfer_batch_to_device = self.dummy_transfer_to_device
return dataloader
def test_dataloader(self):
@@ -292,18 +293,18 @@ class PinaDataModule(LightningDataModule):
"""
raise NotImplementedError("Predict dataloader not implemented")
def dummy_transfer_to_device(self, batch, device, dataloader_idx):
return batch
def transfer_batch_to_device(self, batch, device, dataloader_idx):
"""
Transfer the batch to the device. This method is called in the
training loop and is used to transfer the batch to the device.
"""
if isinstance(batch, list):
return batch
batch = [
(k, super(LightningDataModule, self).transfer_batch_to_device(v,
device,
dataloader_idx))
for k, v in batch.items()
]
return batch