This reverts commit b498797bfef45414d1a50147e3f1097b7179e5a8. Co-authored-by: Filippo Olivo <filippo@filippoolivo.com> Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
6b122d8b2c
commit
9aed1a30b3
@@ -256,9 +256,9 @@ class PinaDataModule(LightningDataModule):
|
||||
sampler = PinaBatchSampler(self.val_dataset, self.batch_size, shuffle=False)
|
||||
return DataLoader(self.val_dataset, sampler=sampler,
|
||||
collate_fn=collate)
|
||||
dataloader = DummyDataloader(self.val_dataset, self.trainer.strategy.root_device)
|
||||
dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device)
|
||||
dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0)
|
||||
return dataloader
|
||||
self.transfer_batch_to_device = self.dummy_transfer_to_device
|
||||
|
||||
def train_dataloader(self):
|
||||
"""
|
||||
@@ -278,6 +278,7 @@ class PinaDataModule(LightningDataModule):
|
||||
collate_fn=collate)
|
||||
dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device)
|
||||
dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0)
|
||||
self.transfer_batch_to_device = self.dummy_transfer_to_device
|
||||
return dataloader
|
||||
|
||||
def test_dataloader(self):
|
||||
@@ -292,18 +293,18 @@ class PinaDataModule(LightningDataModule):
|
||||
"""
|
||||
raise NotImplementedError("Predict dataloader not implemented")
|
||||
|
||||
def dummy_transfer_to_device(self, batch, device, dataloader_idx):
|
||||
return batch
|
||||
|
||||
def transfer_batch_to_device(self, batch, device, dataloader_idx):
|
||||
"""
|
||||
Transfer the batch to the device. This method is called in the
|
||||
training loop and is used to transfer the batch to the device.
|
||||
"""
|
||||
if isinstance(batch, list):
|
||||
return batch
|
||||
batch = [
|
||||
(k, super(LightningDataModule, self).transfer_batch_to_device(v,
|
||||
device,
|
||||
dataloader_idx))
|
||||
for k, v in batch.items()
|
||||
]
|
||||
|
||||
return batch
|
||||
|
||||
Reference in New Issue
Block a user