equation class, fix minor bugs, diff domain (#89)

* equation class
* difference domain
* dummy dataloader
* writer class
* refactoring and minor fix
Nicola Demo
2023-05-15 16:06:01 +02:00
parent be11110bb2
commit 0e3625de80
25 changed files with 691 additions and 246 deletions


@@ -98,63 +98,146 @@ def is_function(f):
return type(f) == types.FunctionType or type(f) == types.LambdaType
def chebyshev_roots(n):
    """
    Return the *n* roots of the Chebyshev polynomial of degree *n*
    (all lying in [-1, 1]).

    :param int n: number of roots
    :return: roots
    :rtype: torch.tensor
    """
    pi = torch.acos(torch.zeros(1)).item() * 2
    k = torch.arange(n)
    nodes = torch.sort(torch.cos(pi * (k + 0.5) / n))[0]
    return nodes
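# Illustrative check only (not library code; it assumes ``torch`` is already
# imported at module level, as elsewhere in this file): the two roots of the
# degree-2 Chebyshev polynomial are +-1/sqrt(2), returned in ascending order.
#
#     chebyshev_roots(2)   # tensor([-0.7071,  0.7071])
#     chebyshev_roots(3)   # tensor([-0.8660,  0.0000,  0.8660])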
# class PinaDataset():

#     def __init__(self, pinn) -> None:
#         self.pinn = pinn

#     @property
#     def dataloader(self):
#         return self._create_dataloader()

#     @property
#     def dataset(self):
#         return [self.SampleDataset(key, val)
#                 for key, val in self.input_pts.items()]

#     def _create_dataloader(self):
#         """Private method for creating dataloader

#         :return: dataloader
#         :rtype: torch.utils.data.DataLoader
#         """
#         if self.pinn.batch_size is None:
#             return {key: [{key: val}] for key, val in self.pinn.input_pts.items()}

#         def custom_collate(batch):
#             # extracting pts labels
#             _, pts = list(batch[0].items())[0]
#             labels = pts.labels
#             # calling default torch collate
#             collate_res = default_collate(batch)
#             # save collate result in dict
#             res = {}
#             for key, val in collate_res.items():
#                 val.labels = labels
#                 res[key] = val
#             return res

#         # creating dataset, list of dataset for each location
#         datasets = [self.SampleDataset(key, val)
#                     for key, val in self.pinn.input_pts.items()]
#         # creating dataloader
#         dataloaders = [DataLoader(dataset=dat,
#                                   batch_size=self.pinn.batch_size,
#                                   collate_fn=custom_collate)
#                        for dat in datasets]

#         return dict(zip(self.pinn.input_pts.keys(), dataloaders))

#     class SampleDataset(torch.utils.data.Dataset):

#         def __init__(self, location, tensor):
#             self._tensor = tensor
#             self._location = location
#             self._len = len(tensor)

#         def __getitem__(self, index):
#             tensor = self._tensor.select(0, index)
#             return {self._location: tensor}

#         def __len__(self):
#             return self._len
from torch.utils.data import Dataset, DataLoader


class LabelTensorDataset(Dataset):
    """Dataset storing one labelled tensor per attribute."""

    def __init__(self, d):
        # store each tensor of the input dictionary as an attribute
        for k, v in d.items():
            setattr(self, k, v)
        self.labels = list(d.keys())

    def __getitem__(self, index):
        result = {}
        for label in self.labels:
            sample_tensor = getattr(self, label)
            try:
                result[label] = sample_tensor[index]
            except IndexError:
                # stored tensors may have different lengths: shorter ones
                # fall back to an empty tensor for out-of-range indices
                result[label] = torch.tensor([])
        return result

    def __len__(self):
        # the longest stored tensor sets the dataset length
        return max([len(getattr(self, label)) for label in self.labels])


class LabelTensorDataLoader(DataLoader):

    def collate_fn(self, data):
        # collate logic for labelled tensors is not implemented yet
        raise NotImplementedError
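A minimal usage sketch for the dataset class above (illustrative only: it uses plain torch tensors in place of PINA's LabelTensor, and the condition names gamma1/gamma2 are made up):

import torch

# hypothetical sample points, one tensor per condition/location
pts = {
    'gamma1': torch.rand(10, 2),
    'gamma2': torch.rand(4, 2),
}

dataset = LabelTensorDataset(pts)
len(dataset)   # 10 -- the longest stored tensor sets the dataset length
dataset[0]     # {'gamma1': tensor of shape (2,), 'gamma2': tensor of shape (2,)}
dataset[7]     # 'gamma2' has only 4 samples, so it falls back to an empty tensor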