batch_enhancement (#51)
@@ -1,10 +1,12 @@
"""Utils module"""
from functools import reduce
import torch
from torch.utils.data import DataLoader, default_collate, ConcatDataset

from .label_tensor import LabelTensor


def number_parameters(model, aggregate=True, only_trainable=True): #TODO: check
def number_parameters(model, aggregate=True, only_trainable=True): # TODO: check
    """
    Return the number of parameters of a given `model`.

@@ -43,5 +45,67 @@ def merge_two_tensors(tensor1, tensor2):

    tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels)
    tensor2 = LabelTensor(tensor2.repeat_interleave(n1, dim=0),
        labels=tensor2.labels)
                          labels=tensor2.labels)
    return tensor1.append(tensor2)


class PinaDataset():

    def __init__(self, pinn) -> None:
        self.pinn = pinn

    @property
    def dataloader(self):
        return self._create_dataloader()

    @property
    def dataset(self):
        return [self.SampleDataset(key, val)
                for key, val in self.pinn.input_pts.items()]

    def _create_dataloader(self):
        """Private method for creating dataloader

        :return: dataloader
        :rtype: torch.utils.data.DataLoader
        """
        if self.pinn.batch_size is None:
            return {key: [{key: val}] for key, val in self.pinn.input_pts.items()}

        def custom_collate(batch):
            # extracting pts labels
            _, pts = list(batch[0].items())[0]
            labels = pts.labels
            # calling default torch collate
            collate_res = default_collate(batch)
            # save collate result in dict
            res = {}
            for key, val in collate_res.items():
                val.labels = labels
                res[key] = val
            return res

        # creating dataset, list of dataset for each location
        datasets = [self.SampleDataset(key, val)
                    for key, val in self.pinn.input_pts.items()]
        # creating dataloader
        dataloaders = [DataLoader(dataset=dat,
                                  batch_size=self.pinn.batch_size,
                                  collate_fn=custom_collate)
                       for dat in datasets]

        return dict(zip(self.pinn.input_pts.keys(), dataloaders))

    class SampleDataset(torch.utils.data.Dataset):

        def __init__(self, location, tensor):
            self._tensor = tensor
            self._location = location
            self._len = len(tensor)

        def __getitem__(self, index):
            tensor = self._tensor.select(0, index)
            return {self._location: tensor}

        def __len__(self):
            return self._len
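
The body of number_parameters is outside both hunks, so only its signature and the first docstring line are visible here. For the default arguments it presumably reduces to the usual PyTorch parameter count; a sketch of that assumption, not the actual implementation:

import torch

def count_trainable_parameters(model: torch.nn.Module) -> int:
    # assumed equivalent of number_parameters(model, aggregate=True,
    # only_trainable=True): sum the element counts of trainable parameters
    return sum(p.numel() for p in model.parameters() if p.requires_grad)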
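
The changed lines in merge_two_tensors build a cartesian product of the two point sets: tensor1.repeat(n2, 1) cycles through the rows of tensor1, while tensor2.repeat_interleave(n1, dim=0) holds each row of tensor2 for n1 consecutive positions, so every (tensor1 row, tensor2 row) pair appears exactly once. A quick check with plain tensors, using torch.cat in place of LabelTensor.append (assumed here to concatenate along the label dimension):

import torch

x = torch.tensor([[1.], [2.]])            # n1 = 2 samples
t = torch.tensor([[10.], [20.], [30.]])   # n2 = 3 samples

left = x.repeat(3, 1)                     # rows: 1, 2, 1, 2, 1, 2
right = t.repeat_interleave(2, dim=0)     # rows: 10, 10, 20, 20, 30, 30
grid = torch.cat((left, right), dim=1)    # all 6 (x, t) combinations
print(grid.shape)                         # torch.Size([6, 2])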
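
As far as this diff shows, the new PinaDataset only reads two attributes of the object it wraps: pinn.input_pts, a dict mapping a location name to a LabelTensor of sample points, and pinn.batch_size. Its dataloader property returns a dict with one entry per location: with a batch size set, each entry is a torch DataLoader whose custom_collate re-attaches the labels that are lost when default_collate stacks the rows; with batch_size=None, each entry is a single-element list wrapping the full tensor, so the same nested loop covers both cases. The nested SampleDataset just serves one row at a time, keyed by its location. A minimal usage sketch under those assumptions (the stand-in pinn object and the LabelTensor import path are illustrative, not taken from the diff):

from types import SimpleNamespace

import torch

from pina.label_tensor import LabelTensor   # assumed import path

pinn = SimpleNamespace(                      # only the attributes PinaDataset reads
    input_pts={
        "gamma": LabelTensor(torch.rand(8, 2), labels=["x", "y"]),
        "D": LabelTensor(torch.rand(32, 2), labels=["x", "y"]),
    },
    batch_size=4,
)

loaders = PinaDataset(pinn).dataloader       # {"gamma": DataLoader, "D": DataLoader}
for location, loader in loaders.items():
    for batch in loader:
        pts = batch[location]                # batch of points, labels re-attached
        # ... evaluate the residual/loss on this batch of points ...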