This commit is contained in:
FilippoOlivo
2025-01-23 15:04:15 +01:00
committed by Nicola Demo
parent 3ea05e845d
commit 6b122d8b2c

View File

@@ -256,9 +256,9 @@ class PinaDataModule(LightningDataModule):
sampler = PinaBatchSampler(self.val_dataset, self.batch_size, shuffle=False) sampler = PinaBatchSampler(self.val_dataset, self.batch_size, shuffle=False)
return DataLoader(self.val_dataset, sampler=sampler, return DataLoader(self.val_dataset, sampler=sampler,
collate_fn=collate) collate_fn=collate)
dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device) dataloader = DummyDataloader(self.val_dataset, self.trainer.strategy.root_device)
dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0) dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0)
self.transfer_batch_to_device = self.dummy_transfer_to_device return dataloader
def train_dataloader(self): def train_dataloader(self):
""" """
@@ -278,7 +278,6 @@ class PinaDataModule(LightningDataModule):
collate_fn=collate) collate_fn=collate)
dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device) dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device)
dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0) dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0)
self.transfer_batch_to_device = self.dummy_transfer_to_device
return dataloader return dataloader
def test_dataloader(self): def test_dataloader(self):
@@ -293,18 +292,18 @@ class PinaDataModule(LightningDataModule):
""" """
raise NotImplementedError("Predict dataloader not implemented") raise NotImplementedError("Predict dataloader not implemented")
def dummy_transfer_to_device(self, batch, device, dataloader_idx):
return batch
def transfer_batch_to_device(self, batch, device, dataloader_idx): def transfer_batch_to_device(self, batch, device, dataloader_idx):
""" """
Transfer the batch to the device. This method is called in the Transfer the batch to the device. This method is called in the
training loop and is used to transfer the batch to the device. training loop and is used to transfer the batch to the device.
""" """
if isinstance(batch, list):
return batch
batch = [ batch = [
(k, super(LightningDataModule, self).transfer_batch_to_device(v, (k, super(LightningDataModule, self).transfer_batch_to_device(v,
device, device,
dataloader_idx)) dataloader_idx))
for k, v in batch.items() for k, v in batch.items()
] ]
return batch return batch