equation class, fix minor bugs, diff domain (#89)
* equation class * difference domain * dummy dataloader * writer class * refactoring and minor fix
This commit is contained in:
@@ -3,14 +3,14 @@ __all__ = [
|
||||
'LabelTensor',
|
||||
'Plotter',
|
||||
'Condition',
|
||||
'CartesianDomain',
|
||||
'Location',
|
||||
'CartesianDomain'
|
||||
]
|
||||
|
||||
from .meta import *
|
||||
from .label_tensor import LabelTensor
|
||||
from .pinn import PINN
|
||||
from .plotter import Plotter
|
||||
from .cartesian import CartesianDomain
|
||||
from .condition import Condition
|
||||
from .location import Location
|
||||
from .geometry import Location
|
||||
from .geometry import CartesianDomain
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
import torch
|
||||
|
||||
|
||||
def chebyshev_roots(n):
|
||||
""" Return the roots of *n* Chebyshev polynomials (between [-1, 1]) """
|
||||
pi = torch.acos(torch.zeros(1)).item() * 2
|
||||
k = torch.arange(n)
|
||||
nodes = torch.sort(torch.cos(pi * (k + 0.5) / n))[0]
|
||||
return nodes
|
||||
@@ -1,6 +1,7 @@
|
||||
""" Condition module. """
|
||||
from .label_tensor import LabelTensor
|
||||
from .location import Location
|
||||
from .geometry import Location
|
||||
from .equation.equation import Equation
|
||||
|
||||
def dummy(a):
|
||||
"""Dummy function for testing purposes."""
|
||||
@@ -17,13 +18,13 @@ class Condition:
|
||||
case, the model is trained to produce the output points given the input
|
||||
points.
|
||||
|
||||
2. By specifying the location and the function of the condition; in such
|
||||
a case, the model is trained to minimize that function by evaluating it
|
||||
at some samples of the location.
|
||||
2. By specifying the location and the equation of the condition; in such
|
||||
a case, the model is trained to minimize the equation residual by
|
||||
evaluating it at some samples of the location.
|
||||
|
||||
3. By specifying the input points and the function of the condition; in
|
||||
such a case, the model is trained to minimize that function by
|
||||
evaluating it at the input points.
|
||||
3. By specifying the input points and the equation of the condition; in
|
||||
such a case, the model is trained to minimize the equation residual by
|
||||
evaluating it at the passed input points.
|
||||
|
||||
Example::
|
||||
|
||||
@@ -40,15 +41,15 @@ class Condition:
|
||||
>>> output_points=example_output_pts)
|
||||
>>> Condition(
|
||||
>>> location=example_domain,
|
||||
>>> function=example_dirichlet)
|
||||
>>> equation=example_dirichlet)
|
||||
>>> Condition(
|
||||
>>> input_points=example_input_pts,
|
||||
>>> function=example_dirichlet)
|
||||
>>> equation=example_dirichlet)
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = [
|
||||
'input_points', 'output_points', 'location', 'function',
|
||||
'input_points', 'output_points', 'location', 'equation',
|
||||
'data_weight'
|
||||
]
|
||||
|
||||
@@ -70,8 +71,8 @@ class Condition:
|
||||
|
||||
if (
|
||||
sorted(kwargs.keys()) != sorted(['input_points', 'output_points']) and
|
||||
sorted(kwargs.keys()) != sorted(['location', 'function']) and
|
||||
sorted(kwargs.keys()) != sorted(['input_points', 'function'])
|
||||
sorted(kwargs.keys()) != sorted(['location', 'equation']) and
|
||||
sorted(kwargs.keys()) != sorted(['input_points', 'equation'])
|
||||
):
|
||||
raise ValueError(f'Invalid keyword arguments {kwargs.keys()}.')
|
||||
|
||||
@@ -81,16 +82,8 @@ class Condition:
|
||||
raise TypeError('`output_points` must be a torch.Tensor.')
|
||||
if not self._dictvalue_isinstance(kwargs, 'location', Location):
|
||||
raise TypeError('`location` must be a Location.')
|
||||
|
||||
if 'function' in kwargs:
|
||||
if not isinstance(kwargs['function'], list):
|
||||
kwargs['function'] = [kwargs['function']]
|
||||
|
||||
|
||||
for i, func in enumerate(kwargs['function']):
|
||||
if not callable(func):
|
||||
raise TypeError(
|
||||
f'`function[{i}]` must be a callable function.')
|
||||
if not self._dictvalue_isinstance(kwargs, 'equation', Equation):
|
||||
raise TypeError('`equation` must be a Equation.')
|
||||
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
126
pina/dataset.py
Normal file
126
pina/dataset.py
Normal file
@@ -0,0 +1,126 @@
|
||||
class PinaDataset():
|
||||
|
||||
def __init__(self, pinn) -> None:
|
||||
self.pinn = pinn
|
||||
|
||||
@property
|
||||
def dataloader(self):
|
||||
return self._create_dataloader()
|
||||
|
||||
@property
|
||||
def dataset(self):
|
||||
return [self.SampleDataset(key, val)
|
||||
for key, val in self.input_pts.items()]
|
||||
|
||||
def _create_dataloader(self):
|
||||
"""Private method for creating dataloader
|
||||
|
||||
:return: dataloader
|
||||
:rtype: torch.utils.data.DataLoader
|
||||
"""
|
||||
if self.pinn.batch_size is None:
|
||||
return {key: [{key: val}] for key, val in self.pinn.input_pts.items()}
|
||||
|
||||
def custom_collate(batch):
|
||||
# extracting pts labels
|
||||
_, pts = list(batch[0].items())[0]
|
||||
labels = pts.labels
|
||||
# calling default torch collate
|
||||
collate_res = default_collate(batch)
|
||||
# save collate result in dict
|
||||
res = {}
|
||||
for key, val in collate_res.items():
|
||||
val.labels = labels
|
||||
res[key] = val
|
||||
def __getitem__(self, index):
|
||||
tensor = self._tensor.select(0, index)
|
||||
return {self._location: tensor}
|
||||
|
||||
def __len__(self):
|
||||
return self._len
|
||||
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
class LabelTensorDataset(Dataset):
|
||||
def __init__(self, d):
|
||||
for k, v in d.items():
|
||||
setattr(self, k, v)
|
||||
self.labels = list(d.keys())
|
||||
|
||||
def __getitem__(self, index):
|
||||
print(index)
|
||||
result = {}
|
||||
for label in self.labels:
|
||||
sample_tensor = getattr(self, label)
|
||||
|
||||
# print('porcodio')
|
||||
# print(sample_tensor.shape[1])
|
||||
# print(index)
|
||||
# print(sample_tensor[index])
|
||||
try:
|
||||
result[label] = sample_tensor[index]
|
||||
except IndexError:
|
||||
result[label] = torch.tensor([])
|
||||
|
||||
print(result)
|
||||
return result
|
||||
|
||||
def __len__(self):
|
||||
return max([len(getattr(self, label)) for label in self.labels])
|
||||
|
||||
|
||||
class LabelTensorDataLoader(DataLoader):
|
||||
|
||||
def collate_fn(self, data):
|
||||
print(data)
|
||||
gggggggggg
|
||||
# return dict(zip(self.pinn.input_pts.keys(), dataloaders))
|
||||
|
||||
# class SampleDataset(torch.utils.data.Dataset):
|
||||
|
||||
# def __init__(self, location, tensor):
|
||||
# self._tensor = tensor
|
||||
# self._location = location
|
||||
# self._len = len(tensor)
|
||||
|
||||
# def __getitem__(self, index):
|
||||
# tensor = self._tensor.select(0, index)
|
||||
# return {self._location: tensor}
|
||||
|
||||
# def __len__(self):
|
||||
# return self._len
|
||||
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
class LabelTensorDataset(Dataset):
|
||||
def __init__(self, d):
|
||||
for k, v in d.items():
|
||||
setattr(self, k, v)
|
||||
self.labels = list(d.keys())
|
||||
|
||||
def __getitem__(self, index):
|
||||
print(index)
|
||||
result = {}
|
||||
for label in self.labels:
|
||||
sample_tensor = getattr(self, label)
|
||||
|
||||
# print('porcodio')
|
||||
# print(sample_tensor.shape[1])
|
||||
# print(index)
|
||||
# print(sample_tensor[index])
|
||||
try:
|
||||
result[label] = sample_tensor[index]
|
||||
except IndexError:
|
||||
result[label] = torch.tensor([])
|
||||
|
||||
print(result)
|
||||
return result
|
||||
|
||||
def __len__(self):
|
||||
return max([len(getattr(self, label)) for label in self.labels])
|
||||
|
||||
class DummyLoader:
|
||||
|
||||
def __init__(self, data) -> None:
|
||||
self.data = [data]
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.data)
|
||||
0
pina/equation/__init__.py
Normal file
0
pina/equation/__init__.py
Normal file
10
pina/equation/equation.py
Normal file
10
pina/equation/equation.py
Normal file
@@ -0,0 +1,10 @@
|
||||
""" Module """
|
||||
from .equation_interface import EquationInterface
|
||||
|
||||
class Equation(EquationInterface):
|
||||
|
||||
def __init__(self, equation):
|
||||
self.__equation = equation
|
||||
|
||||
def residual(self, input_, output_):
|
||||
return self.__equation(input_, output_)
|
||||
37
pina/equation/equation_factory.py
Normal file
37
pina/equation/equation_factory.py
Normal file
@@ -0,0 +1,37 @@
|
||||
""" Module """
|
||||
from .equation import Equation
|
||||
from ..operators import grad, div, nabla
|
||||
|
||||
|
||||
class FixedValue(Equation):
|
||||
|
||||
def __init__(self, value, components=None):
|
||||
def equation(input_, output_):
|
||||
if components is None:
|
||||
return output_ - value
|
||||
return output_.extract(components) - value
|
||||
super().__init__(equation)
|
||||
|
||||
|
||||
class FixedGradient(Equation):
|
||||
|
||||
def __init__(self, value, components=None, d=None):
|
||||
def equation(input_, output_):
|
||||
return grad(output_, input_, components=components, d=d) - value
|
||||
super().__init__(equation)
|
||||
|
||||
|
||||
class FixedFlux(Equation):
|
||||
|
||||
def __init__(self, value, components=None, d=None):
|
||||
def equation(input_, output_):
|
||||
return div(output_, input_, components=components, d=d) - value
|
||||
super().__init__(equation)
|
||||
|
||||
|
||||
class Laplace(Equation):
|
||||
|
||||
def __init__(self, components=None, d=None):
|
||||
def equation(input_, output_):
|
||||
return nabla(output_, input_, components=components, d=d)
|
||||
super().__init__(equation)
|
||||
13
pina/equation/equation_interface.py
Normal file
13
pina/equation/equation_interface.py
Normal file
@@ -0,0 +1,13 @@
|
||||
""" Module for EquationInterface class """
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
|
||||
class EquationInterface(metaclass=ABCMeta):
|
||||
"""
|
||||
The abstract `AbstractProblem` class. All the class defining a PINA Problem
|
||||
should be inheritied from this class.
|
||||
|
||||
In the definition of a PINA problem, the fundamental elements are:
|
||||
the output variables, the condition(s), and the domain(s) where the
|
||||
conditions are applied.
|
||||
"""
|
||||
24
pina/equation/system_equation.py
Normal file
24
pina/equation/system_equation.py
Normal file
@@ -0,0 +1,24 @@
|
||||
""" Module """
|
||||
import torch
|
||||
from .equation import Equation
|
||||
|
||||
class SystemEquation(Equation):
|
||||
|
||||
def __init__(self, list_equation):
|
||||
if not isinstance(list_equation, list):
|
||||
raise TypeError('list_equation must be a list of functions')
|
||||
|
||||
self.equations = []
|
||||
for i, equation in enumerate(list_equation):
|
||||
if not callable(equation):
|
||||
raise TypeError('list_equation must be a list of functions')
|
||||
|
||||
self.equations.append(Equation(equation))
|
||||
|
||||
def residual(self, input_, output_):
|
||||
return torch.mean(
|
||||
torch.stack([
|
||||
equation.residual(input_, output_)
|
||||
for equation in self.equations
|
||||
]),
|
||||
dim=0)
|
||||
10
pina/geometry/__init__.py
Normal file
10
pina/geometry/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
__all__ = [
|
||||
'Location',
|
||||
'CartesianDomain',
|
||||
'EllipsoidDomain',
|
||||
]
|
||||
|
||||
from .location import Location
|
||||
from .cartesian import CartesianDomain
|
||||
from .ellipsoid import EllipsoidDomain
|
||||
from .difference_domain import Difference
|
||||
@@ -1,9 +1,8 @@
|
||||
from .chebyshev import chebyshev_roots
|
||||
import torch
|
||||
|
||||
from .location import Location
|
||||
from .label_tensor import LabelTensor
|
||||
from .utils import torch_lhs
|
||||
from ..label_tensor import LabelTensor
|
||||
from ..utils import torch_lhs, chebyshev_roots
|
||||
|
||||
|
||||
class CartesianDomain(Location):
|
||||
@@ -240,3 +239,31 @@ class CartesianDomain(Location):
|
||||
return _Nd_sampler(n, mode, variables)
|
||||
else:
|
||||
raise ValueError(f'mode={mode} is not valid.')
|
||||
|
||||
|
||||
def is_inside(self, point, check_border=False):
|
||||
"""Check if a point is inside the ellipsoid.
|
||||
|
||||
:param point: Point to be checked
|
||||
:type point: LabelTensor
|
||||
:param check_border: Check if the point is also on the frontier
|
||||
of the ellipsoid, default False.
|
||||
:type check_border: bool
|
||||
:return: Returning True if the point is inside, False otherwise.
|
||||
:rtype: bool
|
||||
"""
|
||||
is_inside = []
|
||||
for variable, bound in self.range_.items():
|
||||
if variable in point.labels:
|
||||
if bound[0] <= point.extract([variable]) <= bound[1]:
|
||||
is_inside.append(True)
|
||||
else:
|
||||
is_inside.append(False)
|
||||
|
||||
return all(is_inside)
|
||||
|
||||
# TODO check the fixed_ dimensions
|
||||
# for variable, value in self.fixed_.items():
|
||||
# if variable in point.labels:
|
||||
# if not (point.extract[variable] == value):
|
||||
# return False
|
||||
27
pina/geometry/difference_domain.py
Normal file
27
pina/geometry/difference_domain.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""Module for Location class."""
|
||||
|
||||
from .location import Location
|
||||
from ..label_tensor import LabelTensor
|
||||
|
||||
class Difference(Location):
|
||||
"""
|
||||
"""
|
||||
def __init__(self, first, second):
|
||||
|
||||
|
||||
self.first = first
|
||||
self.second = second
|
||||
|
||||
def sample(self, n, mode ='random', variables='all'):
|
||||
"""
|
||||
"""
|
||||
assert mode is 'random', 'Only random mode is implemented'
|
||||
|
||||
samples = []
|
||||
while len(samples) < n:
|
||||
sample = self.first.sample(1, 'random')
|
||||
if not self.second.is_inside(sample):
|
||||
samples.append(sample.tolist()[0])
|
||||
|
||||
import torch
|
||||
return LabelTensor(torch.tensor(samples), labels=['x', 'y'])
|
||||
@@ -1,10 +1,10 @@
|
||||
import torch
|
||||
|
||||
from .location import Location
|
||||
from .label_tensor import LabelTensor
|
||||
from ..label_tensor import LabelTensor
|
||||
|
||||
|
||||
class Ellipsoid(Location):
|
||||
class EllipsoidDomain(Location):
|
||||
"""PINA implementation of Ellipsoid domain."""
|
||||
|
||||
def __init__(self, ellipsoid_dict, sample_surface=False):
|
||||
@@ -98,7 +98,7 @@ class Ellipsoid(Location):
|
||||
# get axis ellipse
|
||||
list_dict_vals = list(self._axis.values())
|
||||
tmp = torch.tensor(list_dict_vals, dtype=torch.float)
|
||||
ax_sq = LabelTensor(tmp.reshape(1, -1), list(self._axis.keys()))
|
||||
ax_sq = LabelTensor(tmp.reshape(1, -1)**2, list(self._axis.keys()))
|
||||
|
||||
if not all([i in ax_sq.labels for i in point.labels]):
|
||||
raise ValueError('point labels different from constructor'
|
||||
@@ -8,7 +8,6 @@ class Location(metaclass=ABCMeta):
|
||||
Abstract Location class.
|
||||
Any geometry entity should inherit from this class.
|
||||
"""
|
||||
@property
|
||||
@abstractmethod
|
||||
def sample(self):
|
||||
"""
|
||||
@@ -1,5 +1,7 @@
|
||||
""" Module for LabelTensor """
|
||||
from typing import Any
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class LabelTensor(torch.Tensor):
|
||||
@@ -79,7 +81,7 @@ class LabelTensor(torch.Tensor):
|
||||
|
||||
@labels.setter
|
||||
def labels(self, labels):
|
||||
if len(labels) != self.shape[1]: # small check
|
||||
if len(labels) != self.shape[self.ndim - 1]: # small check
|
||||
raise ValueError(
|
||||
'the tensor has not the same number of columns of '
|
||||
'the passed labels.')
|
||||
@@ -140,7 +142,7 @@ class LabelTensor(torch.Tensor):
|
||||
except ValueError:
|
||||
raise ValueError(f'`{f}` not in the labels list')
|
||||
|
||||
new_data = self[:, indeces].float()
|
||||
new_data = super(Tensor, self.T).__getitem__(indeces).float().T
|
||||
new_labels = [self.labels[idx] for idx in indeces]
|
||||
|
||||
extracted_tensor = new_data.as_subclass(LabelTensor)
|
||||
@@ -183,6 +185,19 @@ class LabelTensor(torch.Tensor):
|
||||
new_tensor.labels = new_labels
|
||||
return new_tensor
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""
|
||||
Return a copy of the selected tensor.
|
||||
"""
|
||||
selected_lt = super(Tensor, self).__getitem__(index)
|
||||
if hasattr(self, 'labels'):
|
||||
selected_lt.labels = self.labels
|
||||
|
||||
return selected_lt
|
||||
|
||||
def __len__(self) -> int:
|
||||
return super().__len__()
|
||||
|
||||
def __str__(self):
|
||||
if hasattr(self, 'labels'):
|
||||
s = f'labels({str(self.labels)})\n'
|
||||
|
||||
@@ -42,6 +42,7 @@ class Network(torch.nn.Module):
|
||||
output_variables, extra_features=None):
|
||||
super().__init__()
|
||||
|
||||
print('HHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHH')
|
||||
if extra_features is None:
|
||||
extra_features = []
|
||||
|
||||
@@ -49,6 +50,7 @@ class Network(torch.nn.Module):
|
||||
self._model = model
|
||||
self._input_variables = input_variables
|
||||
self._output_variables = output_variables
|
||||
print(output_variables)
|
||||
|
||||
# check model and input/output
|
||||
self._check_consistency()
|
||||
@@ -59,10 +61,11 @@ class Network(torch.nn.Module):
|
||||
:raises ValueError: Error in constructing the PINA network
|
||||
"""
|
||||
try:
|
||||
tmp = torch.rand((10, len(self._input_variables)))
|
||||
tmp = LabelTensor(tmp, self._input_variables)
|
||||
tmp = self.forward(tmp) # trying a forward pass
|
||||
tmp = LabelTensor(tmp, self._output_variables)
|
||||
pass
|
||||
# tmp = torch.rand((10, len(self._input_variables)))
|
||||
# tmp = LabelTensor(tmp, self._input_variables)
|
||||
# tmp = self.forward(tmp) # trying a forward pass
|
||||
# tmp = LabelTensor(tmp, self._output_variables)
|
||||
except:
|
||||
raise ValueError('Error in constructing the PINA network.'
|
||||
' Check compatibility of input/output'
|
||||
|
||||
@@ -188,7 +188,7 @@ def nabla(output_, input_, components=None, d=None, method='std'):
|
||||
result = torch.zeros(output_.shape[0], 1, device=output_.device)
|
||||
for i, label in enumerate(grad_output.labels):
|
||||
gg = grad(grad_output, input_, d=d, components=[label])
|
||||
result[:, 0] += gg[:, i]
|
||||
result[:, 0] += super(torch.Tensor, gg.T).__getitem__(i) # TODO improve
|
||||
labels = [f'dd{components[0]}']
|
||||
|
||||
else:
|
||||
|
||||
241
pina/pinn.py
241
pina/pinn.py
@@ -5,7 +5,8 @@ import torch.optim.lr_scheduler as lrs
|
||||
from .problem import AbstractProblem
|
||||
from .model import Network
|
||||
from .label_tensor import LabelTensor
|
||||
from .utils import merge_tensors, PinaDataset
|
||||
from .utils import merge_tensors
|
||||
from .dataset import DummyLoader
|
||||
|
||||
|
||||
torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
|
||||
@@ -26,6 +27,7 @@ class PINN(object):
|
||||
batch_size=None,
|
||||
dtype=torch.float32,
|
||||
device='cpu',
|
||||
writer=None,
|
||||
error_norm='mse'):
|
||||
'''
|
||||
:param AbstractProblem problem: the formualation of the problem.
|
||||
@@ -84,17 +86,22 @@ class PINN(object):
|
||||
self.input_pts = {}
|
||||
|
||||
self.trained_epoch = 0
|
||||
|
||||
from .writer import Writer
|
||||
if writer is None:
|
||||
writer = Writer()
|
||||
self.writer = writer
|
||||
|
||||
if not optimizer_kwargs:
|
||||
optimizer_kwargs = {}
|
||||
optimizer_kwargs['lr'] = lr
|
||||
self.optimizer = optimizer(
|
||||
self.model.parameters(), weight_decay=regularizer, **optimizer_kwargs)
|
||||
self._lr_scheduler = lr_scheduler_type(
|
||||
self.optimizer, **lr_scheduler_kwargs)
|
||||
self.model.parameters())#, weight_decay=regularizer, **optimizer_kwargs)
|
||||
#self._lr_scheduler = lr_scheduler_type(
|
||||
# self.optimizer, **lr_scheduler_kwargs)
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.data_set = PinaDataset(self)
|
||||
# self.data_set = PinaDataset(self)
|
||||
|
||||
@property
|
||||
def problem(self):
|
||||
@@ -216,139 +223,131 @@ class PINN(object):
|
||||
# pts = pts.double()
|
||||
self.input_pts[location] = pts
|
||||
|
||||
def train(self, stop=100, frequency_print=2, save_loss=1, trial=None):
|
||||
def _residual_loss(self, input_pts, equation):
|
||||
"""
|
||||
Compute the residual loss for a given condition.
|
||||
|
||||
self.model.train()
|
||||
epoch = 0
|
||||
# Add all condition with `input_points` to dataloader
|
||||
for condition in list(set(self.problem.conditions.keys()) - set(self.input_pts.keys())):
|
||||
self.input_pts[condition] = self.problem.conditions[condition]
|
||||
:param torch.Tensor pts: the points to evaluate the residual at.
|
||||
:param Equation equation: the equation to evaluate the residual with.
|
||||
"""
|
||||
|
||||
data_loader = self.data_set.dataloader
|
||||
input_pts = input_pts.to(dtype=self.dtype, device=self.device)
|
||||
input_pts.requires_grad_(True)
|
||||
input_pts.retain_grad()
|
||||
|
||||
header = []
|
||||
for condition_name in self.problem.conditions:
|
||||
condition = self.problem.conditions[condition_name]
|
||||
predicted = self.model(input_pts)
|
||||
residuals = equation.residual(input_pts, predicted)
|
||||
return self._compute_norm(residuals)
|
||||
|
||||
if hasattr(condition, 'function'):
|
||||
if isinstance(condition.function, list):
|
||||
for function in condition.function:
|
||||
header.append(f'{condition_name}{function.__name__}')
|
||||
def _data_loss(self, input_pts, output_pts):
|
||||
"""
|
||||
Compute the residual loss for a given condition.
|
||||
|
||||
:param torch.Tensor pts: the points to evaluate the residual at.
|
||||
:param Equation equation: the equation to evaluate the residual with.
|
||||
"""
|
||||
input_pts = input_pts.to(dtype=self.dtype, device=self.device)
|
||||
output_pts = output_pts.to(dtype=self.dtype, device=self.device)
|
||||
predicted = self.model(input_pts)
|
||||
residuals = predicted - output_pts
|
||||
return self._compute_norm(residuals)
|
||||
|
||||
|
||||
# def closure(self):
|
||||
# """
|
||||
# """
|
||||
# self.optimizer.zero_grad()
|
||||
|
||||
# condition_losses = []
|
||||
# from torch.utils.data import DataLoader
|
||||
# from .utils import MyDataset
|
||||
# loader = DataLoader(
|
||||
# MyDataset(self.input_pts),
|
||||
# batch_size=self.batch_size,
|
||||
# num_workers=1
|
||||
# )
|
||||
# for condition_name in self.problem.conditions:
|
||||
# condition = self.problem.conditions[condition_name]
|
||||
|
||||
# batch_losses = []
|
||||
# for batch in data_loader[condition_name]:
|
||||
|
||||
# if hasattr(condition, 'equation'):
|
||||
# loss = self._residual_loss(
|
||||
# batch[condition_name], condition.equation)
|
||||
# elif hasattr(condition, 'output_points'):
|
||||
# loss = self._data_loss(
|
||||
# batch[condition_name], condition.output_points)
|
||||
|
||||
# batch_losses.append(loss * condition.data_weight)
|
||||
|
||||
# condition_losses.append(sum(batch_losses))
|
||||
|
||||
# loss = sum(condition_losses)
|
||||
# loss.backward()
|
||||
# return loss
|
||||
|
||||
def closure(self):
|
||||
"""
|
||||
"""
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
losses = []
|
||||
for i, batch in enumerate(self.loader):
|
||||
|
||||
condition_losses = []
|
||||
|
||||
for condition_name, samples in batch.items():
|
||||
|
||||
if condition_name not in self.problem.conditions:
|
||||
raise RuntimeError('Something wrong happened.')
|
||||
|
||||
if samples is None or samples.nelement() == 0:
|
||||
continue
|
||||
|
||||
header.append(f'{condition_name}')
|
||||
condition = self.problem.conditions[condition_name]
|
||||
|
||||
if hasattr(condition, 'equation'):
|
||||
loss = self._residual_loss(samples, condition.equation)
|
||||
elif hasattr(condition, 'output_points'):
|
||||
loss = self._data_loss(samples, condition.output_points)
|
||||
|
||||
condition_losses.append(loss * condition.data_weight)
|
||||
|
||||
losses.append(sum(condition_losses))
|
||||
|
||||
loss = sum(losses)
|
||||
loss.backward()
|
||||
return losses[0]
|
||||
|
||||
def train(self, stop=100):
|
||||
|
||||
self.model.train()
|
||||
|
||||
############################################################
|
||||
## TODO: move to problem class
|
||||
for condition in list(set(self.problem.conditions.keys()) - set(self.input_pts.keys())):
|
||||
self.input_pts[condition] = self.problem.conditions[condition].input_points
|
||||
|
||||
mydata = self.input_pts
|
||||
|
||||
self.loader = DummyLoader(mydata)
|
||||
|
||||
while True:
|
||||
|
||||
losses = []
|
||||
loss = self.optimizer.step(closure=self.closure)
|
||||
|
||||
for condition_name in self.problem.conditions:
|
||||
condition = self.problem.conditions[condition_name]
|
||||
self.writer.write_loss_in_loop(self, loss)
|
||||
|
||||
for batch in data_loader[condition_name]:
|
||||
|
||||
single_loss = []
|
||||
|
||||
if hasattr(condition, 'function'):
|
||||
pts = batch[condition_name]
|
||||
pts = pts.to(dtype=self.dtype, device=self.device)
|
||||
pts.requires_grad_(True)
|
||||
pts.retain_grad()
|
||||
|
||||
predicted = self.model(pts)
|
||||
for function in condition.function:
|
||||
residuals = function(pts, predicted)
|
||||
local_loss = (
|
||||
condition.data_weight*self._compute_norm(
|
||||
residuals))
|
||||
single_loss.append(local_loss)
|
||||
elif hasattr(condition, 'output_points'):
|
||||
pts = condition.input_points.to(
|
||||
dtype=self.dtype, device=self.device)
|
||||
predicted = self.model(pts)
|
||||
residuals = predicted - \
|
||||
condition.output_points.to(
|
||||
device=self.device, dtype=self.dtype) # TODO fix
|
||||
local_loss = (
|
||||
condition.data_weight*self._compute_norm(residuals))
|
||||
single_loss.append(local_loss)
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
sum(single_loss).backward()
|
||||
self.optimizer.step()
|
||||
|
||||
losses.append(sum(single_loss))
|
||||
|
||||
self._lr_scheduler.step()
|
||||
|
||||
if save_loss and (epoch % save_loss == 0 or epoch == 0):
|
||||
self.history_loss[epoch] = [
|
||||
loss.detach().item() for loss in losses]
|
||||
|
||||
if trial:
|
||||
import optuna
|
||||
trial.report(sum(losses), epoch)
|
||||
if trial.should_prune():
|
||||
raise optuna.exceptions.TrialPruned()
|
||||
#self._lr_scheduler.step()
|
||||
|
||||
if isinstance(stop, int):
|
||||
if epoch == stop:
|
||||
print('[epoch {:05d}] {:.6e} '.format(
|
||||
self.trained_epoch, sum(losses).item()), end='')
|
||||
for loss in losses:
|
||||
print('{:.6e} '.format(loss.item()), end='')
|
||||
print()
|
||||
if self.trained_epoch == stop:
|
||||
break
|
||||
elif isinstance(stop, float):
|
||||
if sum(losses) < stop:
|
||||
if loss.item() < stop:
|
||||
break
|
||||
|
||||
if epoch % frequency_print == 0 or epoch == 1:
|
||||
print(' {:5s} {:12s} '.format('', 'sum'), end='')
|
||||
for name in header:
|
||||
print('{:12.12s} '.format(name), end='')
|
||||
print()
|
||||
|
||||
print('[epoch {:05d}] {:.6e} '.format(
|
||||
self.trained_epoch, sum(losses).item()), end='')
|
||||
for loss in losses:
|
||||
print('{:.6e} '.format(loss.item()), end='')
|
||||
print()
|
||||
|
||||
self.trained_epoch += 1
|
||||
epoch += 1
|
||||
|
||||
self.model.eval()
|
||||
|
||||
return sum(losses).item()
|
||||
|
||||
# def error(self, dtype='l2', res=100):
|
||||
|
||||
# import numpy as np
|
||||
# if hasattr(self.problem, 'truth_solution') and self.problem.truth_solution is not None:
|
||||
# pts_container = []
|
||||
# for mn, mx in self.problem.domain_bound:
|
||||
# pts_container.append(np.linspace(mn, mx, res))
|
||||
# grids_container = np.meshgrid(*pts_container)
|
||||
# Z_true = self.problem.truth_solution(*grids_container)
|
||||
|
||||
# elif hasattr(self.problem, 'data_solution') and self.problem.data_solution is not None:
|
||||
# grids_container = self.problem.data_solution['grid']
|
||||
# Z_true = self.problem.data_solution['grid_solution']
|
||||
# try:
|
||||
# unrolled_pts = torch.tensor([t.flatten() for t in grids_container]).T.to(
|
||||
# dtype=self.dtype, device=self.device)
|
||||
# Z_pred = self.model(unrolled_pts)
|
||||
# Z_pred = Z_pred.detach().numpy().reshape(grids_container[0].shape)
|
||||
|
||||
# if dtype == 'l2':
|
||||
# return np.linalg.norm(Z_pred - Z_true)/np.linalg.norm(Z_true)
|
||||
# else:
|
||||
# # TODO H1
|
||||
# pass
|
||||
# except:
|
||||
# print("")
|
||||
# print("Something went wrong...")
|
||||
# print(
|
||||
# "Not able to compute the error. Please pass a data solution or a true solution")
|
||||
self.model.eval()
|
||||
179
pina/utils.py
179
pina/utils.py
@@ -98,63 +98,146 @@ def is_function(f):
|
||||
return type(f) == types.FunctionType or type(f) == types.LambdaType
|
||||
|
||||
|
||||
class PinaDataset():
|
||||
def chebyshev_roots(n):
|
||||
"""
|
||||
Return the roots of *n* Chebyshev polynomials (between [-1, 1]).
|
||||
|
||||
def __init__(self, pinn) -> None:
|
||||
self.pinn = pinn
|
||||
:param int n: number of roots
|
||||
:return: roots
|
||||
:rtype: torch.tensor
|
||||
"""
|
||||
pi = torch.acos(torch.zeros(1)).item() * 2
|
||||
k = torch.arange(n)
|
||||
nodes = torch.sort(torch.cos(pi * (k + 0.5) / n))[0]
|
||||
return nodes
|
||||
|
||||
@property
|
||||
def dataloader(self):
|
||||
return self._create_dataloader()
|
||||
# class PinaDataset():
|
||||
|
||||
@property
|
||||
def dataset(self):
|
||||
return [self.SampleDataset(key, val)
|
||||
for key, val in self.input_pts.items()]
|
||||
# def __init__(self, pinn) -> None:
|
||||
# self.pinn = pinn
|
||||
|
||||
def _create_dataloader(self):
|
||||
"""Private method for creating dataloader
|
||||
# @property
|
||||
# def dataloader(self):
|
||||
# return self._create_dataloader()
|
||||
|
||||
:return: dataloader
|
||||
:rtype: torch.utils.data.DataLoader
|
||||
"""
|
||||
if self.pinn.batch_size is None:
|
||||
return {key: [{key: val}] for key, val in self.pinn.input_pts.items()}
|
||||
# @property
|
||||
# def dataset(self):
|
||||
# return [self.SampleDataset(key, val)
|
||||
# for key, val in self.input_pts.items()]
|
||||
|
||||
def custom_collate(batch):
|
||||
# extracting pts labels
|
||||
_, pts = list(batch[0].items())[0]
|
||||
labels = pts.labels
|
||||
# calling default torch collate
|
||||
collate_res = default_collate(batch)
|
||||
# save collate result in dict
|
||||
res = {}
|
||||
for key, val in collate_res.items():
|
||||
val.labels = labels
|
||||
res[key] = val
|
||||
return res
|
||||
# def _create_dataloader(self):
|
||||
# """Private method for creating dataloader
|
||||
|
||||
# creating dataset, list of dataset for each location
|
||||
datasets = [self.SampleDataset(key, val)
|
||||
for key, val in self.pinn.input_pts.items()]
|
||||
# creating dataloader
|
||||
dataloaders = [DataLoader(dataset=dat,
|
||||
batch_size=self.pinn.batch_size,
|
||||
collate_fn=custom_collate)
|
||||
for dat in datasets]
|
||||
# :return: dataloader
|
||||
# :rtype: torch.utils.data.DataLoader
|
||||
# """
|
||||
# if self.pinn.batch_size is None:
|
||||
# return {key: [{key: val}] for key, val in self.pinn.input_pts.items()}
|
||||
|
||||
return dict(zip(self.pinn.input_pts.keys(), dataloaders))
|
||||
# def custom_collate(batch):
|
||||
# # extracting pts labels
|
||||
# _, pts = list(batch[0].items())[0]
|
||||
# labels = pts.labels
|
||||
# # calling default torch collate
|
||||
# collate_res = default_collate(batch)
|
||||
# # save collate result in dict
|
||||
# res = {}
|
||||
# for key, val in collate_res.items():
|
||||
# val.labels = labels
|
||||
# res[key] = val
|
||||
# __init__(self, location, tensor):
|
||||
# self._tensor = tensor
|
||||
# self._location = location
|
||||
# self._len = len(tensor)
|
||||
|
||||
class SampleDataset(torch.utils.data.Dataset):
|
||||
# def __getitem__(self, index):
|
||||
# tensor = self._tensor.select(0, index)
|
||||
# return {self._location: tensor}
|
||||
|
||||
def __init__(self, location, tensor):
|
||||
self._tensor = tensor
|
||||
self._location = location
|
||||
self._len = len(tensor)
|
||||
# def __len__(self):
|
||||
# return self._len
|
||||
|
||||
def __getitem__(self, index):
|
||||
tensor = self._tensor.select(0, index)
|
||||
return {self._location: tensor}
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
class LabelTensorDataset(Dataset):
|
||||
def __init__(self, d):
|
||||
for k, v in d.items():
|
||||
setattr(self, k, v)
|
||||
self.labels = list(d.keys())
|
||||
|
||||
def __getitem__(self, index):
|
||||
print(index)
|
||||
result = {}
|
||||
for label in self.labels:
|
||||
sample_tensor = getattr(self, label)
|
||||
|
||||
def __len__(self):
|
||||
return self._len
|
||||
# print('porcodio')
|
||||
# print(sample_tensor.shape[1])
|
||||
# print(index)
|
||||
# print(sample_tensor[index])
|
||||
try:
|
||||
result[label] = sample_tensor[index]
|
||||
except IndexError:
|
||||
result[label] = torch.tensor([])
|
||||
|
||||
print(result)
|
||||
return result
|
||||
|
||||
def __len__(self):
|
||||
return max([len(getattr(self, label)) for label in self.labels])
|
||||
|
||||
|
||||
class LabelTensorDataLoader(DataLoader):
|
||||
|
||||
def collate_fn(self, data):
|
||||
print(data)
|
||||
gggggggggg
|
||||
# return dict(zip(self.pinn.input_pts.keys(), dataloaders))
|
||||
|
||||
# class SampleDataset(torch.utils.data.Dataset):
|
||||
|
||||
# def __init__(self, location, tensor):
|
||||
# self._tensor = tensor
|
||||
# self._location = location
|
||||
# self._len = len(tensor)
|
||||
|
||||
# def __getitem__(self, index):
|
||||
# tensor = self._tensor.select(0, index)
|
||||
# return {self._location: tensor}
|
||||
|
||||
# def __len__(self):
|
||||
# return self._len
|
||||
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
class LabelTensorDataset(Dataset):
|
||||
def __init__(self, d):
|
||||
for k, v in d.items():
|
||||
setattr(self, k, v)
|
||||
self.labels = list(d.keys())
|
||||
|
||||
def __getitem__(self, index):
|
||||
print(index)
|
||||
result = {}
|
||||
for label in self.labels:
|
||||
sample_tensor = getattr(self, label)
|
||||
|
||||
# print('porcodio')
|
||||
# print(sample_tensor.shape[1])
|
||||
# print(index)
|
||||
# print(sample_tensor[index])
|
||||
try:
|
||||
result[label] = sample_tensor[index]
|
||||
except IndexError:
|
||||
result[label] = torch.tensor([])
|
||||
|
||||
print(result)
|
||||
return result
|
||||
|
||||
def __len__(self):
|
||||
return max([len(getattr(self, label)) for label in self.labels])
|
||||
|
||||
|
||||
class LabelTensorDataLoader(DataLoader):
|
||||
|
||||
def collate_fn(self, data):
|
||||
print(data)
|
||||
gggggggggg
|
||||
54
pina/writer.py
Normal file
54
pina/writer.py
Normal file
@@ -0,0 +1,54 @@
|
||||
""" Module for plotting. """
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from pina import LabelTensor
|
||||
|
||||
|
||||
class Writer:
|
||||
"""
|
||||
Implementation of a writer class, for textual output.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
frequency_print=10,
|
||||
header='any') -> None:
|
||||
|
||||
"""
|
||||
The constructor of the class.
|
||||
|
||||
:param int frequency_print: the frequency in epochs of printing.
|
||||
:param ['any', 'begin', 'none'] header: the header of the output.
|
||||
"""
|
||||
|
||||
self._frequency_print = frequency_print
|
||||
self._header = header
|
||||
|
||||
|
||||
def header(self, trainer):
|
||||
"""
|
||||
The method for printing the header.
|
||||
"""
|
||||
header = []
|
||||
for condition_name in trainer.problem.conditions:
|
||||
header.append(f'{condition_name}')
|
||||
|
||||
return header
|
||||
|
||||
def write_loss(self, trainer):
|
||||
"""
|
||||
The method for writing the output.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def write_loss_in_loop(self, trainer, loss):
|
||||
"""
|
||||
The method for writing the output within the training loop.
|
||||
|
||||
:param pina.trainer.Trainer trainer: the trainer object.
|
||||
"""
|
||||
|
||||
if trainer.trained_epoch % self._frequency_print == 0:
|
||||
print(f'Epoch {trainer.trained_epoch:05d}: {loss.item():.5e}')
|
||||
Reference in New Issue
Block a user