committed by
Nicola Demo
parent
f0d68b34c7
commit
30f865d912
@@ -49,11 +49,19 @@ class Stokes(SpatialProblem):
|
|||||||
value = 0.0
|
value = 0.0
|
||||||
return output_.extract(['ux', 'uy']) - value
|
return output_.extract(['ux', 'uy']) - value
|
||||||
|
|
||||||
|
domains = {
|
||||||
|
'gamma_top': CartesianDomain({'x': [-2, 2], 'y': 1}),
|
||||||
|
'gamma_bot': CartesianDomain({'x': [-2, 2], 'y': -1}),
|
||||||
|
'gamma_out': CartesianDomain({'x': 2, 'y': [-1, 1]}),
|
||||||
|
'gamma_in': CartesianDomain({'x': -2, 'y': [-1, 1]}),
|
||||||
|
'D': CartesianDomain({'x': [-2, 2], 'y': [-1, 1]})
|
||||||
|
}
|
||||||
|
|
||||||
# problem condition statement
|
# problem condition statement
|
||||||
conditions = {
|
conditions = {
|
||||||
'gamma_top': Condition(location=CartesianDomain({'x': [-2, 2], 'y': 1}), equation=Equation(wall)),
|
'gamma_top': Condition(domain='gamma_top', equation=Equation(wall)),
|
||||||
'gamma_bot': Condition(location=CartesianDomain({'x': [-2, 2], 'y': -1}), equation=Equation(wall)),
|
'gamma_bot': Condition(domain='gamma_bot', equation=Equation(wall)),
|
||||||
'gamma_out': Condition(location=CartesianDomain({'x': 2, 'y': [-1, 1]}), equation=Equation(outlet)),
|
'gamma_out': Condition(domain='gamma_out', equation=Equation(outlet)),
|
||||||
'gamma_in': Condition(location=CartesianDomain({'x': -2, 'y': [-1, 1]}), equation=Equation(inlet)),
|
'gamma_in': Condition(domain='gamma_in', equation=Equation(inlet)),
|
||||||
'D': Condition(location=CartesianDomain({'x': [-2, 2], 'y': [-1, 1]}), equation=SystemEquation([momentum, continuity]))
|
'D': Condition(domain='D', equation=SystemEquation([momentum, continuity]))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,8 +17,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# create problem and discretise domain
|
# create problem and discretise domain
|
||||||
stokes_problem = Stokes()
|
stokes_problem = Stokes()
|
||||||
stokes_problem.discretise_domain(n=1000, locations=['gamma_top', 'gamma_bot', 'gamma_in', 'gamma_out'])
|
stokes_problem.discretise_domain(n=1000, domains=['gamma_top', 'gamma_bot', 'gamma_in', 'gamma_out'])
|
||||||
stokes_problem.discretise_domain(n=2000, locations=['D'])
|
stokes_problem.discretise_domain(n=2000, domains=['D'])
|
||||||
|
|
||||||
# make the model
|
# make the model
|
||||||
model = FeedForward(
|
model = FeedForward(
|
||||||
|
|||||||
@@ -84,14 +84,15 @@ class Condition:
|
|||||||
return DomainEquationCondition(**kwargs)
|
return DomainEquationCondition(**kwargs)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
|
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
|
||||||
|
# TODO: remove, not used anymore
|
||||||
|
'''
|
||||||
if (
|
if (
|
||||||
sorted(kwargs.keys()) != sorted(["input_points", "output_points"])
|
sorted(kwargs.keys()) != sorted(["input_points", "output_points"])
|
||||||
and sorted(kwargs.keys()) != sorted(["location", "equation"])
|
and sorted(kwargs.keys()) != sorted(["location", "equation"])
|
||||||
and sorted(kwargs.keys()) != sorted(["input_points", "equation"])
|
and sorted(kwargs.keys()) != sorted(["input_points", "equation"])
|
||||||
):
|
):
|
||||||
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
|
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
|
||||||
|
# TODO: remove, not used anymore
|
||||||
if not self._dictvalue_isinstance(kwargs, "input_points", LabelTensor):
|
if not self._dictvalue_isinstance(kwargs, "input_points", LabelTensor):
|
||||||
raise TypeError("`input_points` must be a torch.Tensor.")
|
raise TypeError("`input_points` must be a torch.Tensor.")
|
||||||
if not self._dictvalue_isinstance(kwargs, "output_points", LabelTensor):
|
if not self._dictvalue_isinstance(kwargs, "output_points", LabelTensor):
|
||||||
@@ -103,3 +104,4 @@ class Condition:
|
|||||||
|
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
|
'''
|
||||||
@@ -16,3 +16,6 @@ class ConditionInterface(metaclass=ABCMeta):
|
|||||||
:return: The residual of the condition.
|
:return: The residual of the condition.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def set_problem(self, problem):
|
||||||
|
self._problem = problem
|
||||||
|
|||||||
@@ -15,6 +15,12 @@ class DomainEquationCondition(ConditionInterface):
|
|||||||
self.domain = domain
|
self.domain = domain
|
||||||
self.equation = equation
|
self.equation = equation
|
||||||
|
|
||||||
|
def residual(self, model):
|
||||||
|
"""
|
||||||
|
Compute the residual of the condition.
|
||||||
|
"""
|
||||||
|
self.batch_residual(model, self.domain, self.equation)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def batch_residual(model, input_pts, equation):
|
def batch_residual(model, input_pts, equation):
|
||||||
"""
|
"""
|
||||||
@@ -22,7 +28,7 @@ class DomainEquationCondition(ConditionInterface):
|
|||||||
output points are provided as arguments.
|
output points are provided as arguments.
|
||||||
|
|
||||||
:param torch.nn.Module model: The model to evaluate the condition.
|
:param torch.nn.Module model: The model to evaluate the condition.
|
||||||
:param torch.Tensor input_points: The input points.
|
:param torch.Tensor input_pts: The input points.
|
||||||
:param torch.Tensor output_points: The output points.
|
:param torch.Tensor equation: The output points.
|
||||||
"""
|
"""
|
||||||
return equation.residual(model(input_pts))
|
return equation.residual(input_pts, model(input_pts))
|
||||||
@@ -40,4 +40,5 @@ class DomainOutputCondition(ConditionInterface):
|
|||||||
:param torch.Tensor input_points: The input points.
|
:param torch.Tensor input_points: The input points.
|
||||||
:param torch.Tensor output_points: The output points.
|
:param torch.Tensor output_points: The output points.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return output_points - model(input_points)
|
return output_points - model(input_points)
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
|
import torch
|
||||||
|
|
||||||
from .domain_interface import DomainInterface
|
from .domain_interface import DomainInterface
|
||||||
from ..label_tensor import LabelTensor
|
from ..label_tensor import LabelTensor
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import torch
|
|||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# class LabelTensor(torch.Tensor):
|
# class LabelTensor(torch.Tensor):
|
||||||
# """Torch tensor with a label for any column."""
|
# """Torch tensor with a label for any column."""
|
||||||
|
|
||||||
@@ -307,13 +306,13 @@ from torch import Tensor
|
|||||||
# s = "no labels\n"
|
# s = "no labels\n"
|
||||||
# s += super().__str__()
|
# s += super().__str__()
|
||||||
# return s
|
# return s
|
||||||
|
|
||||||
def issubset(a, b):
|
def issubset(a, b):
|
||||||
"""
|
"""
|
||||||
Check if a is a subset of b.
|
Check if a is a subset of b.
|
||||||
"""
|
"""
|
||||||
return set(a).issubset(set(b))
|
return set(a).issubset(set(b))
|
||||||
|
|
||||||
|
|
||||||
class LabelTensor(torch.Tensor):
|
class LabelTensor(torch.Tensor):
|
||||||
"""Torch tensor with a label for any column."""
|
"""Torch tensor with a label for any column."""
|
||||||
|
|
||||||
@@ -403,6 +402,10 @@ class LabelTensor(torch.Tensor):
|
|||||||
return LabelTensor(new_tensor, label_to_extract)
|
return LabelTensor(new_tensor, label_to_extract)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
"""
|
||||||
|
returns a string with the representation of the class
|
||||||
|
"""
|
||||||
|
|
||||||
s = ''
|
s = ''
|
||||||
for key, value in self.labels.items():
|
for key, value in self.labels.items():
|
||||||
s += f"{key}: {value}\n"
|
s += f"{key}: {value}\n"
|
||||||
@@ -432,3 +435,31 @@ class LabelTensor(torch.Tensor):
|
|||||||
@property
|
@property
|
||||||
def dtype(self):
|
def dtype(self):
|
||||||
return super().dtype
|
return super().dtype
|
||||||
|
|
||||||
|
|
||||||
|
def to(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Performs Tensor dtype and/or device conversion. For more details, see
|
||||||
|
:meth:`torch.Tensor.to`.
|
||||||
|
"""
|
||||||
|
tmp = super().to(*args, **kwargs)
|
||||||
|
new = self.__class__.clone(self)
|
||||||
|
new.data = tmp.data
|
||||||
|
return new
|
||||||
|
|
||||||
|
|
||||||
|
def clone(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Clone the LabelTensor. For more details, see
|
||||||
|
:meth:`torch.Tensor.clone`.
|
||||||
|
|
||||||
|
:return: A copy of the tensor.
|
||||||
|
:rtype: LabelTensor
|
||||||
|
"""
|
||||||
|
# # used before merging
|
||||||
|
# try:
|
||||||
|
# out = LabelTensor(super().clone(*args, **kwargs), self.labels)
|
||||||
|
# except:
|
||||||
|
# out = super().clone(*args, **kwargs)
|
||||||
|
out = LabelTensor(super().clone(*args, **kwargs), self.labels)
|
||||||
|
return out
|
||||||
@@ -20,7 +20,6 @@ class AbstractProblem(metaclass=ABCMeta):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
||||||
|
|
||||||
self._discretized_domains = {}
|
self._discretized_domains = {}
|
||||||
|
|
||||||
for name, domain in self.domains.items():
|
for name, domain in self.domains.items():
|
||||||
@@ -28,18 +27,19 @@ class AbstractProblem(metaclass=ABCMeta):
|
|||||||
self._discretized_domains[name] = domain
|
self._discretized_domains[name] = domain
|
||||||
|
|
||||||
for condition_name in self.conditions:
|
for condition_name in self.conditions:
|
||||||
self.conditions[condition_name]._problem = self
|
self.conditions[condition_name].set_problem(self)
|
||||||
|
|
||||||
# # variable storing all points
|
# # variable storing all points
|
||||||
# self.input_pts = {}
|
self.input_pts = {}
|
||||||
|
|
||||||
# # varible to check if sampling is done. If no location
|
# # varible to check if sampling is done. If no location
|
||||||
# # element is presented in Condition this variable is set to true
|
# # element is presented in Condition this variable is set to true
|
||||||
# self._have_sampled_points = {}
|
# self._have_sampled_points = {}
|
||||||
# for condition_name in self.conditions:
|
for condition_name in self.conditions:
|
||||||
# self._have_sampled_points[condition_name] = False
|
self._discretized_domains[condition_name] = False
|
||||||
|
|
||||||
# # put in self.input_pts all the points that we don't need to sample
|
# # put in self.input_pts all the points that we don't need to sample
|
||||||
# self._span_condition_points()
|
self._span_condition_points()
|
||||||
|
|
||||||
def __deepcopy__(self, memo):
|
def __deepcopy__(self, memo):
|
||||||
"""
|
"""
|
||||||
@@ -125,7 +125,7 @@ class AbstractProblem(metaclass=ABCMeta):
|
|||||||
if hasattr(condition, "input_points"):
|
if hasattr(condition, "input_points"):
|
||||||
samples = condition.input_points
|
samples = condition.input_points
|
||||||
self.input_pts[condition_name] = samples
|
self.input_pts[condition_name] = samples
|
||||||
self._have_sampled_points[condition_name] = True
|
self._discretized_domains[condition_name] = True
|
||||||
if hasattr(self, "unknown_parameter_domain"):
|
if hasattr(self, "unknown_parameter_domain"):
|
||||||
# initialize the unknown parameters of the inverse problem given
|
# initialize the unknown parameters of the inverse problem given
|
||||||
# the domain the user gives
|
# the domain the user gives
|
||||||
@@ -141,7 +141,7 @@ class AbstractProblem(metaclass=ABCMeta):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def discretise_domain(
|
def discretise_domain(
|
||||||
self, n, mode="random", variables="all", locations="all"
|
self, n, mode="random", variables="all", domains="all"
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Generate a set of points to span the `Location` of all the conditions of
|
Generate a set of points to span the `Location` of all the conditions of
|
||||||
@@ -192,31 +192,37 @@ class AbstractProblem(metaclass=ABCMeta):
|
|||||||
f"should be in {self.input_variables}.",
|
f"should be in {self.input_variables}.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# check consistency location
|
# # check consistency location # TODO: check if this is needed (from 0.1)
|
||||||
locations_to_sample = [
|
# locations_to_sample = [
|
||||||
condition
|
# condition
|
||||||
for condition in self.conditions
|
# for condition in self.conditions
|
||||||
if hasattr(self.conditions[condition], "location")
|
# if hasattr(self.conditions[condition], "location")
|
||||||
]
|
# ]
|
||||||
if locations == "all":
|
# if locations == "all":
|
||||||
# only locations that can be sampled
|
# # only locations that can be sampled
|
||||||
locations = locations_to_sample
|
# locations = locations_to_sample
|
||||||
else:
|
# else:
|
||||||
check_consistency(locations, str)
|
# check_consistency(locations, str)
|
||||||
|
|
||||||
if sorted(locations) != sorted(locations_to_sample):
|
# if sorted(locations) != sorted(locations_to_sample):
|
||||||
|
if domains == "all":
|
||||||
|
domains = [condition for condition in self.conditions]
|
||||||
|
else:
|
||||||
|
check_consistency(domains, str)
|
||||||
|
print(domains)
|
||||||
|
if sorted(domains) != sorted(self.conditions):
|
||||||
TypeError(
|
TypeError(
|
||||||
f"Wrong locations for sampling. Location ",
|
f"Wrong locations for sampling. Location ",
|
||||||
f"should be in {locations_to_sample}.",
|
f"should be in {locations_to_sample}.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# sampling
|
# sampling
|
||||||
for location in locations:
|
for d in domains:
|
||||||
condition = self.conditions[location]
|
condition = self.conditions[d]
|
||||||
|
|
||||||
# we try to check if we have already sampled
|
# we try to check if we have already sampled
|
||||||
try:
|
try:
|
||||||
already_sampled = [self.input_pts[location]]
|
already_sampled = [self.input_pts[d]]
|
||||||
# if we have not sampled, a key error is thrown
|
# if we have not sampled, a key error is thrown
|
||||||
except KeyError:
|
except KeyError:
|
||||||
already_sampled = []
|
already_sampled = []
|
||||||
@@ -225,25 +231,27 @@ class AbstractProblem(metaclass=ABCMeta):
|
|||||||
# but we want to sample again we set already_sampled
|
# but we want to sample again we set already_sampled
|
||||||
# to an empty list since we need to sample again, and
|
# to an empty list since we need to sample again, and
|
||||||
# self._have_sampled_points to False.
|
# self._have_sampled_points to False.
|
||||||
if self._have_sampled_points[location]:
|
if self._discretized_domains[d]:
|
||||||
already_sampled = []
|
already_sampled = []
|
||||||
self._have_sampled_points[location] = False
|
self._discretized_domains[d] = False
|
||||||
|
print(condition.domain)
|
||||||
|
print(d)
|
||||||
# build samples
|
# build samples
|
||||||
samples = [
|
samples = [
|
||||||
condition.location.sample(n=n, mode=mode, variables=variables)
|
self.domains[d].sample(n=n, mode=mode, variables=variables)
|
||||||
] + already_sampled
|
] + already_sampled
|
||||||
pts = merge_tensors(samples)
|
pts = merge_tensors(samples)
|
||||||
self.input_pts[location] = pts
|
self.input_pts[d] = pts
|
||||||
|
|
||||||
# the condition is sampled if input_pts contains all labels
|
# the condition is sampled if input_pts contains all labels
|
||||||
if sorted(self.input_pts[location].labels) == sorted(
|
if sorted(self.input_pts[d].labels) == sorted(
|
||||||
self.input_variables
|
self.input_variables
|
||||||
):
|
):
|
||||||
self._have_sampled_points[location] = True
|
# self._have_sampled_points[location] = True
|
||||||
self.input_pts[location] = self.input_pts[location].extract(
|
# self.input_pts[location] = self.input_pts[location].extract(
|
||||||
sorted(self.input_variables)
|
# sorted(self.input_variables)
|
||||||
)
|
# )
|
||||||
|
self._have_sampled_points[d] = True
|
||||||
|
|
||||||
def add_points(self, new_points):
|
def add_points(self, new_points):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -134,8 +134,6 @@ class SupervisedSolver(SolverInterface):
|
|||||||
condition = self.problem.conditions[condition_name]
|
condition = self.problem.conditions[condition_name]
|
||||||
pts = batch.input
|
pts = batch.input
|
||||||
out = batch.output
|
out = batch.output
|
||||||
print(out)
|
|
||||||
print(pts)
|
|
||||||
|
|
||||||
if condition_name not in self.problem.conditions:
|
if condition_name not in self.problem.conditions:
|
||||||
raise RuntimeError("Something wrong happened.")
|
raise RuntimeError("Something wrong happened.")
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from pina import Condition, LabelTensor
|
|||||||
from pina.solvers import SupervisedSolver
|
from pina.solvers import SupervisedSolver
|
||||||
from pina.trainer import Trainer
|
from pina.trainer import Trainer
|
||||||
from pina.model import FeedForward
|
from pina.model import FeedForward
|
||||||
from pina.loss.loss_interface import LpLoss
|
from pina.loss import LpLoss
|
||||||
|
|
||||||
|
|
||||||
class NeuralOperatorProblem(AbstractProblem):
|
class NeuralOperatorProblem(AbstractProblem):
|
||||||
@@ -94,11 +94,9 @@ class GraphModel(torch.nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
def test_graph():
|
def test_graph():
|
||||||
|
|
||||||
solver = AutoSolver(problem = problem, model=GraphModel(2, 1), loss=LpLoss())
|
solver = AutoSolver(problem = problem, model=GraphModel(2, 1), loss=LpLoss())
|
||||||
trainer = Trainer(solver=solver, max_epochs=30, accelerator='cpu', batch_size=20)
|
trainer = Trainer(solver=solver, max_epochs=30, accelerator='cpu', batch_size=20)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
assert False
|
|
||||||
|
|
||||||
|
|
||||||
def test_train_cpu():
|
def test_train_cpu():
|
||||||
@@ -107,6 +105,7 @@ def test_train_cpu():
|
|||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# def test_train_restore():
|
# def test_train_restore():
|
||||||
# tmpdir = "tests/tmp_restore"
|
# tmpdir = "tests/tmp_restore"
|
||||||
# solver = SupervisedSolver(problem=problem,
|
# solver = SupervisedSolver(problem=problem,
|
||||||
|
|||||||
Reference in New Issue
Block a user