This reverts commit b498797bfef45414d1a50147e3f1097b7179e5a8. Co-authored-by: Filippo Olivo <filippo@filippoolivo.com> Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
6b122d8b2c
commit
9aed1a30b3
@@ -256,9 +256,9 @@ class PinaDataModule(LightningDataModule):
|
|||||||
sampler = PinaBatchSampler(self.val_dataset, self.batch_size, shuffle=False)
|
sampler = PinaBatchSampler(self.val_dataset, self.batch_size, shuffle=False)
|
||||||
return DataLoader(self.val_dataset, sampler=sampler,
|
return DataLoader(self.val_dataset, sampler=sampler,
|
||||||
collate_fn=collate)
|
collate_fn=collate)
|
||||||
dataloader = DummyDataloader(self.val_dataset, self.trainer.strategy.root_device)
|
dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device)
|
||||||
dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0)
|
dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0)
|
||||||
return dataloader
|
self.transfer_batch_to_device = self.dummy_transfer_to_device
|
||||||
|
|
||||||
def train_dataloader(self):
|
def train_dataloader(self):
|
||||||
"""
|
"""
|
||||||
@@ -278,6 +278,7 @@ class PinaDataModule(LightningDataModule):
|
|||||||
collate_fn=collate)
|
collate_fn=collate)
|
||||||
dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device)
|
dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device)
|
||||||
dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0)
|
dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0)
|
||||||
|
self.transfer_batch_to_device = self.dummy_transfer_to_device
|
||||||
return dataloader
|
return dataloader
|
||||||
|
|
||||||
def test_dataloader(self):
|
def test_dataloader(self):
|
||||||
@@ -292,18 +293,18 @@ class PinaDataModule(LightningDataModule):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError("Predict dataloader not implemented")
|
raise NotImplementedError("Predict dataloader not implemented")
|
||||||
|
|
||||||
|
def dummy_transfer_to_device(self, batch, device, dataloader_idx):
|
||||||
|
return batch
|
||||||
|
|
||||||
def transfer_batch_to_device(self, batch, device, dataloader_idx):
|
def transfer_batch_to_device(self, batch, device, dataloader_idx):
|
||||||
"""
|
"""
|
||||||
Transfer the batch to the device. This method is called in the
|
Transfer the batch to the device. This method is called in the
|
||||||
training loop and is used to transfer the batch to the device.
|
training loop and is used to transfer the batch to the device.
|
||||||
"""
|
"""
|
||||||
if isinstance(batch, list):
|
|
||||||
return batch
|
|
||||||
batch = [
|
batch = [
|
||||||
(k, super(LightningDataModule, self).transfer_batch_to_device(v,
|
(k, super(LightningDataModule, self).transfer_batch_to_device(v,
|
||||||
device,
|
device,
|
||||||
dataloader_idx))
|
dataloader_idx))
|
||||||
for k, v in batch.items()
|
for k, v in batch.items()
|
||||||
]
|
]
|
||||||
|
|
||||||
return batch
|
return batch
|
||||||
|
|||||||
Reference in New Issue
Block a user