This commit is contained in:
FilippoOlivo
2025-01-23 15:04:15 +01:00
committed by Nicola Demo
parent 3ea05e845d
commit 6b122d8b2c

View File

@@ -256,9 +256,9 @@ class PinaDataModule(LightningDataModule):
sampler = PinaBatchSampler(self.val_dataset, self.batch_size, shuffle=False)
return DataLoader(self.val_dataset, sampler=sampler,
collate_fn=collate)
dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device)
dataloader = DummyDataloader(self.val_dataset, self.trainer.strategy.root_device)
dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0)
self.transfer_batch_to_device = self.dummy_transfer_to_device
return dataloader
def train_dataloader(self):
"""
@@ -278,7 +278,6 @@ class PinaDataModule(LightningDataModule):
collate_fn=collate)
dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device)
dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0)
self.transfer_batch_to_device = self.dummy_transfer_to_device
return dataloader
def test_dataloader(self):
@@ -293,18 +292,18 @@ class PinaDataModule(LightningDataModule):
"""
raise NotImplementedError("Predict dataloader not implemented")
def dummy_transfer_to_device(self, batch, device, dataloader_idx):
return batch
def transfer_batch_to_device(self, batch, device, dataloader_idx):
"""
Transfer the batch to the device. This method is called in the
training loop and is used to transfer the batch to the device.
"""
if isinstance(batch, list):
return batch
batch = [
(k, super(LightningDataModule, self).transfer_batch_to_device(v,
device,
dataloader_idx))
for k, v in batch.items()
]
return batch