equation class, fix minor bugs, diff domain (#89)
* equation class * difference domain * dummy dataloader * writer class * refactoring and minor fix
This commit is contained in:
179
pina/utils.py
179
pina/utils.py
@@ -98,63 +98,146 @@ def is_function(f):
|
||||
return type(f) == types.FunctionType or type(f) == types.LambdaType
|
||||
|
||||
|
||||
class PinaDataset():
|
||||
def chebyshev_roots(n):
|
||||
"""
|
||||
Return the roots of *n* Chebyshev polynomials (between [-1, 1]).
|
||||
|
||||
def __init__(self, pinn) -> None:
|
||||
self.pinn = pinn
|
||||
:param int n: number of roots
|
||||
:return: roots
|
||||
:rtype: torch.tensor
|
||||
"""
|
||||
pi = torch.acos(torch.zeros(1)).item() * 2
|
||||
k = torch.arange(n)
|
||||
nodes = torch.sort(torch.cos(pi * (k + 0.5) / n))[0]
|
||||
return nodes
|
||||
|
||||
@property
|
||||
def dataloader(self):
|
||||
return self._create_dataloader()
|
||||
# class PinaDataset():
|
||||
|
||||
@property
|
||||
def dataset(self):
|
||||
return [self.SampleDataset(key, val)
|
||||
for key, val in self.input_pts.items()]
|
||||
# def __init__(self, pinn) -> None:
|
||||
# self.pinn = pinn
|
||||
|
||||
def _create_dataloader(self):
|
||||
"""Private method for creating dataloader
|
||||
# @property
|
||||
# def dataloader(self):
|
||||
# return self._create_dataloader()
|
||||
|
||||
:return: dataloader
|
||||
:rtype: torch.utils.data.DataLoader
|
||||
"""
|
||||
if self.pinn.batch_size is None:
|
||||
return {key: [{key: val}] for key, val in self.pinn.input_pts.items()}
|
||||
# @property
|
||||
# def dataset(self):
|
||||
# return [self.SampleDataset(key, val)
|
||||
# for key, val in self.input_pts.items()]
|
||||
|
||||
def custom_collate(batch):
|
||||
# extracting pts labels
|
||||
_, pts = list(batch[0].items())[0]
|
||||
labels = pts.labels
|
||||
# calling default torch collate
|
||||
collate_res = default_collate(batch)
|
||||
# save collate result in dict
|
||||
res = {}
|
||||
for key, val in collate_res.items():
|
||||
val.labels = labels
|
||||
res[key] = val
|
||||
return res
|
||||
# def _create_dataloader(self):
|
||||
# """Private method for creating dataloader
|
||||
|
||||
# creating dataset, list of dataset for each location
|
||||
datasets = [self.SampleDataset(key, val)
|
||||
for key, val in self.pinn.input_pts.items()]
|
||||
# creating dataloader
|
||||
dataloaders = [DataLoader(dataset=dat,
|
||||
batch_size=self.pinn.batch_size,
|
||||
collate_fn=custom_collate)
|
||||
for dat in datasets]
|
||||
# :return: dataloader
|
||||
# :rtype: torch.utils.data.DataLoader
|
||||
# """
|
||||
# if self.pinn.batch_size is None:
|
||||
# return {key: [{key: val}] for key, val in self.pinn.input_pts.items()}
|
||||
|
||||
return dict(zip(self.pinn.input_pts.keys(), dataloaders))
|
||||
# def custom_collate(batch):
|
||||
# # extracting pts labels
|
||||
# _, pts = list(batch[0].items())[0]
|
||||
# labels = pts.labels
|
||||
# # calling default torch collate
|
||||
# collate_res = default_collate(batch)
|
||||
# # save collate result in dict
|
||||
# res = {}
|
||||
# for key, val in collate_res.items():
|
||||
# val.labels = labels
|
||||
# res[key] = val
|
||||
# __init__(self, location, tensor):
|
||||
# self._tensor = tensor
|
||||
# self._location = location
|
||||
# self._len = len(tensor)
|
||||
|
||||
class SampleDataset(torch.utils.data.Dataset):
|
||||
# def __getitem__(self, index):
|
||||
# tensor = self._tensor.select(0, index)
|
||||
# return {self._location: tensor}
|
||||
|
||||
def __init__(self, location, tensor):
|
||||
self._tensor = tensor
|
||||
self._location = location
|
||||
self._len = len(tensor)
|
||||
# def __len__(self):
|
||||
# return self._len
|
||||
|
||||
def __getitem__(self, index):
|
||||
tensor = self._tensor.select(0, index)
|
||||
return {self._location: tensor}
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
class LabelTensorDataset(Dataset):
|
||||
def __init__(self, d):
|
||||
for k, v in d.items():
|
||||
setattr(self, k, v)
|
||||
self.labels = list(d.keys())
|
||||
|
||||
def __getitem__(self, index):
|
||||
print(index)
|
||||
result = {}
|
||||
for label in self.labels:
|
||||
sample_tensor = getattr(self, label)
|
||||
|
||||
def __len__(self):
|
||||
return self._len
|
||||
# print('porcodio')
|
||||
# print(sample_tensor.shape[1])
|
||||
# print(index)
|
||||
# print(sample_tensor[index])
|
||||
try:
|
||||
result[label] = sample_tensor[index]
|
||||
except IndexError:
|
||||
result[label] = torch.tensor([])
|
||||
|
||||
print(result)
|
||||
return result
|
||||
|
||||
def __len__(self):
|
||||
return max([len(getattr(self, label)) for label in self.labels])
|
||||
|
||||
|
||||
class LabelTensorDataLoader(DataLoader):
|
||||
|
||||
def collate_fn(self, data):
|
||||
print(data)
|
||||
gggggggggg
|
||||
# return dict(zip(self.pinn.input_pts.keys(), dataloaders))
|
||||
|
||||
# class SampleDataset(torch.utils.data.Dataset):
|
||||
|
||||
# def __init__(self, location, tensor):
|
||||
# self._tensor = tensor
|
||||
# self._location = location
|
||||
# self._len = len(tensor)
|
||||
|
||||
# def __getitem__(self, index):
|
||||
# tensor = self._tensor.select(0, index)
|
||||
# return {self._location: tensor}
|
||||
|
||||
# def __len__(self):
|
||||
# return self._len
|
||||
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
class LabelTensorDataset(Dataset):
|
||||
def __init__(self, d):
|
||||
for k, v in d.items():
|
||||
setattr(self, k, v)
|
||||
self.labels = list(d.keys())
|
||||
|
||||
def __getitem__(self, index):
|
||||
print(index)
|
||||
result = {}
|
||||
for label in self.labels:
|
||||
sample_tensor = getattr(self, label)
|
||||
|
||||
# print('porcodio')
|
||||
# print(sample_tensor.shape[1])
|
||||
# print(index)
|
||||
# print(sample_tensor[index])
|
||||
try:
|
||||
result[label] = sample_tensor[index]
|
||||
except IndexError:
|
||||
result[label] = torch.tensor([])
|
||||
|
||||
print(result)
|
||||
return result
|
||||
|
||||
def __len__(self):
|
||||
return max([len(getattr(self, label)) for label in self.labels])
|
||||
|
||||
|
||||
class LabelTensorDataLoader(DataLoader):
|
||||
|
||||
def collate_fn(self, data):
|
||||
print(data)
|
||||
gggggggggg
|
||||
Reference in New Issue
Block a user