Fix bugs in 0.2 (#344)

* Fix some bugs
FilippoOlivo
2024-09-12 18:12:59 +02:00
committed by Nicola Demo
parent f0d68b34c7
commit 30f865d912
11 changed files with 112 additions and 55 deletions


@@ -49,11 +49,19 @@ class Stokes(SpatialProblem):
         value = 0.0
         return output_.extract(['ux', 'uy']) - value

+    domains = {
+        'gamma_top': CartesianDomain({'x': [-2, 2], 'y': 1}),
+        'gamma_bot': CartesianDomain({'x': [-2, 2], 'y': -1}),
+        'gamma_out': CartesianDomain({'x': 2, 'y': [-1, 1]}),
+        'gamma_in': CartesianDomain({'x': -2, 'y': [-1, 1]}),
+        'D': CartesianDomain({'x': [-2, 2], 'y': [-1, 1]})
+    }
+
     # problem condition statement
     conditions = {
-        'gamma_top': Condition(location=CartesianDomain({'x': [-2, 2], 'y': 1}), equation=Equation(wall)),
-        'gamma_bot': Condition(location=CartesianDomain({'x': [-2, 2], 'y': -1}), equation=Equation(wall)),
-        'gamma_out': Condition(location=CartesianDomain({'x': 2, 'y': [-1, 1]}), equation=Equation(outlet)),
-        'gamma_in': Condition(location=CartesianDomain({'x': -2, 'y': [-1, 1]}), equation=Equation(inlet)),
-        'D': Condition(location=CartesianDomain({'x': [-2, 2], 'y': [-1, 1]}), equation=SystemEquation([momentum, continuity]))
+        'gamma_top': Condition(domain='gamma_top', equation=Equation(wall)),
+        'gamma_bot': Condition(domain='gamma_bot', equation=Equation(wall)),
+        'gamma_out': Condition(domain='gamma_out', equation=Equation(outlet)),
+        'gamma_in': Condition(domain='gamma_in', equation=Equation(inlet)),
+        'D': Condition(domain='D', equation=SystemEquation([momentum, continuity]))
     }
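The hunk above moves the geometry of the Stokes problem into a class-level `domains` dictionary, and each `Condition` now refers to its domain by name instead of carrying a `location` object. A minimal, hedged sketch of the same pairing; the import paths are assumptions and may differ in the 0.2 package layout:

from pina import Condition
from pina.geometry import CartesianDomain   # assumed path; may be pina.domain in 0.2
from pina.equation import Equation          # assumed path

def wall(input_, output_):
    # illustrative no-slip residual: velocity must vanish on the wall
    return output_.extract(['ux', 'uy'])

# every domain is declared once, under a name ...
domains = {
    'gamma_top': CartesianDomain({'x': [-2, 2], 'y': 1}),
}

# ... and each condition only stores that name
conditions = {
    'gamma_top': Condition(domain='gamma_top', equation=Equation(wall)),
}

In the problem class these two dictionaries sit side by side, so the same key appears once per boundary in `domains` and once in `conditions`.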


@@ -17,8 +17,8 @@ if __name__ == "__main__":

     # create problem and discretise domain
     stokes_problem = Stokes()
-    stokes_problem.discretise_domain(n=1000, locations=['gamma_top', 'gamma_bot', 'gamma_in', 'gamma_out'])
-    stokes_problem.discretise_domain(n=2000, locations=['D'])
+    stokes_problem.discretise_domain(n=1000, domains=['gamma_top', 'gamma_bot', 'gamma_in', 'gamma_out'])
+    stokes_problem.discretise_domain(n=2000, domains=['D'])

     # make the model
     model = FeedForward(


@@ -84,14 +84,15 @@ class Condition:
             return DomainEquationCondition(**kwargs)
         else:
            raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
+        # TODO: remove, not used anymore
+        '''
        if (
            sorted(kwargs.keys()) != sorted(["input_points", "output_points"])
            and sorted(kwargs.keys()) != sorted(["location", "equation"])
            and sorted(kwargs.keys()) != sorted(["input_points", "equation"])
        ):
            raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
-        # TODO: remove, not used anymore
        if not self._dictvalue_isinstance(kwargs, "input_points", LabelTensor):
            raise TypeError("`input_points` must be a torch.Tensor.")
        if not self._dictvalue_isinstance(kwargs, "output_points", LabelTensor):
@@ -103,3 +104,4 @@ class Condition:
        for key, value in kwargs.items():
            setattr(self, key, value)
+        '''


@@ -16,3 +16,6 @@ class ConditionInterface(metaclass=ABCMeta):
        :return: The residual of the condition.
        """
        pass
+
+    def set_problem(self, problem):
+        self._problem = problem
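`set_problem` gives every condition a public hook for the back-reference that `AbstractProblem.__init__` previously created by writing the private `_problem` attribute directly (see the AbstractProblem hunk further down). A self-contained toy sketch of the wiring, not PINA code:

class ToyCondition:
    def set_problem(self, problem):
        # same one-liner as ConditionInterface.set_problem above
        self._problem = problem

class ToyProblem:
    def __init__(self):
        self.conditions = {'wall': ToyCondition()}
        for name in self.conditions:
            # mirrors the loop in AbstractProblem.__init__ after this commit
            self.conditions[name].set_problem(self)

p = ToyProblem()
assert p.conditions['wall']._problem is p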


@@ -15,6 +15,12 @@ class DomainEquationCondition(ConditionInterface):
         self.domain = domain
         self.equation = equation

+    def residual(self, model):
+        """
+        Compute the residual of the condition.
+        """
+        self.batch_residual(model, self.domain, self.equation)
+
     @staticmethod
     def batch_residual(model, input_pts, equation):
         """
@@ -22,7 +28,7 @@ class DomainEquationCondition(ConditionInterface):
         output points are provided as arguments.

         :param torch.nn.Module model: The model to evaluate the condition.
-        :param torch.Tensor input_points: The input points.
-        :param torch.Tensor output_points: The output points.
+        :param torch.Tensor input_pts: The input points.
+        :param torch.Tensor equation: The output points.
         """
-        return equation.residual(model(input_pts))
+        return equation.residual(input_pts, model(input_pts))
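The last line is the substantive fix in this file: `equation.residual` is now called with both the sampled input points and the model output evaluated on them, rather than the output alone. A toy sketch of that calling convention (illustrative names, not PINA code):

import torch

class ToyEquation:
    def residual(self, input_pts, output_pts):
        # enforce u(x) = x, so the residual is u(x) - x
        return output_pts - input_pts

class ToyModel(torch.nn.Module):
    def forward(self, x):
        return 2 * x   # deliberately off, to give a non-zero residual

def batch_residual(model, input_pts, equation):
    # same call shape as DomainEquationCondition.batch_residual after the change
    return equation.residual(input_pts, model(input_pts))

pts = torch.rand(5, 1)
print(batch_residual(ToyModel(), pts, ToyEquation()))   # equals pts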


@@ -40,4 +40,5 @@ class DomainOutputCondition(ConditionInterface):
        :param torch.Tensor input_points: The input points.
        :param torch.Tensor output_points: The output points.
        """
        return output_points - model(input_points)


@@ -1,4 +1,5 @@
 import torch
+import torch

 from .domain_interface import DomainInterface
 from ..label_tensor import LabelTensor


@@ -5,7 +5,6 @@ import torch

 from torch import Tensor

 # class LabelTensor(torch.Tensor):
 #     """Torch tensor with a label for any column."""
@@ -307,13 +306,13 @@ from torch import Tensor
 #         s = "no labels\n"
 #         s += super().__str__()
 #         return s

 def issubset(a, b):
     """
     Check if a is a subset of b.
     """
     return set(a).issubset(set(b))

 class LabelTensor(torch.Tensor):
     """Torch tensor with a label for any column."""
@@ -403,6 +402,10 @@ class LabelTensor(torch.Tensor):
         return LabelTensor(new_tensor, label_to_extract)

     def __str__(self):
+        """
+        returns a string with the representation of the class
+        """
         s = ''
         for key, value in self.labels.items():
             s += f"{key}: {value}\n"
@@ -432,3 +435,31 @@ class LabelTensor(torch.Tensor):
     @property
     def dtype(self):
         return super().dtype
+
+    def to(self, *args, **kwargs):
+        """
+        Performs Tensor dtype and/or device conversion. For more details, see
+        :meth:`torch.Tensor.to`.
+        """
+        tmp = super().to(*args, **kwargs)
+        new = self.__class__.clone(self)
+        new.data = tmp.data
+        return new
+
+    def clone(self, *args, **kwargs):
+        """
+        Clone the LabelTensor. For more details, see
+        :meth:`torch.Tensor.clone`.
+
+        :return: A copy of the tensor.
+        :rtype: LabelTensor
+        """
+        # # used before merging
+        # try:
+        #     out = LabelTensor(super().clone(*args, **kwargs), self.labels)
+        # except:
+        #     out = super().clone(*args, **kwargs)
+        out = LabelTensor(super().clone(*args, **kwargs), self.labels)
+        return out
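The new `to` keeps the column labels across a dtype/device conversion by cloning the labelled tensor and swapping in the converted data. A hedged usage sketch; the list-of-labels constructor is assumed from the 0.1 API:

import torch
from pina import LabelTensor

pts = LabelTensor(torch.rand(10, 2), ['x', 'y'])
pts64 = pts.to(dtype=torch.float64)

print(pts64.dtype)    # expected: torch.float64
print(pts64.labels)   # expected: same labels as pts, preserved by the new `to`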


@@ -20,7 +20,6 @@ class AbstractProblem(metaclass=ABCMeta):

     def __init__(self):
         self._discretized_domains = {}
         for name, domain in self.domains.items():
@@ -28,18 +27,19 @@ class AbstractProblem(metaclass=ABCMeta):
                 self._discretized_domains[name] = domain

         for condition_name in self.conditions:
-            self.conditions[condition_name]._problem = self
+            self.conditions[condition_name].set_problem(self)

         # # variable storing all points
-        # self.input_pts = {}
+        self.input_pts = {}

         # # varible to check if sampling is done. If no location
         # # element is presented in Condition this variable is set to true
         # self._have_sampled_points = {}
-        # for condition_name in self.conditions:
-        #     self._have_sampled_points[condition_name] = False
+        for condition_name in self.conditions:
+            self._discretized_domains[condition_name] = False

         # # put in self.input_pts all the points that we don't need to sample
-        # self._span_condition_points()
+        self._span_condition_points()

     def __deepcopy__(self, memo):
         """
@@ -125,7 +125,7 @@ class AbstractProblem(metaclass=ABCMeta):
             if hasattr(condition, "input_points"):
                 samples = condition.input_points
                 self.input_pts[condition_name] = samples
-                self._have_sampled_points[condition_name] = True
+                self._discretized_domains[condition_name] = True

         if hasattr(self, "unknown_parameter_domain"):
             # initialize the unknown parameters of the inverse problem given
             # the domain the user gives
@@ -141,7 +141,7 @@ class AbstractProblem(metaclass=ABCMeta):
         )

     def discretise_domain(
-        self, n, mode="random", variables="all", locations="all"
+        self, n, mode="random", variables="all", domains="all"
     ):
         """
         Generate a set of points to span the `Location` of all the conditions of
@@ -192,31 +192,37 @@ class AbstractProblem(metaclass=ABCMeta):
                 f"should be in {self.input_variables}.",
             )

-        # check consistency location
-        locations_to_sample = [
-            condition
-            for condition in self.conditions
-            if hasattr(self.conditions[condition], "location")
-        ]
-        if locations == "all":
-            # only locations that can be sampled
-            locations = locations_to_sample
-        else:
-            check_consistency(locations, str)
-            if sorted(locations) != sorted(locations_to_sample):
+        # # check consistency location # TODO: check if this is needed (from 0.1)
+        # locations_to_sample = [
+        #     condition
+        #     for condition in self.conditions
+        #     if hasattr(self.conditions[condition], "location")
+        # ]
+        # if locations == "all":
+        #     # only locations that can be sampled
+        #     locations = locations_to_sample
+        # else:
+        #     check_consistency(locations, str)
+        #     if sorted(locations) != sorted(locations_to_sample):
+        if domains == "all":
+            domains = [condition for condition in self.conditions]
+        else:
+            check_consistency(domains, str)
+        print(domains)
+        if sorted(domains) != sorted(self.conditions):
                 TypeError(
                     f"Wrong locations for sampling. Location ",
                     f"should be in {locations_to_sample}.",
                 )

         # sampling
-        for location in locations:
-            condition = self.conditions[location]
+        for d in domains:
+            condition = self.conditions[d]

             # we try to check if we have already sampled
             try:
-                already_sampled = [self.input_pts[location]]
+                already_sampled = [self.input_pts[d]]
             # if we have not sampled, a key error is thrown
             except KeyError:
                 already_sampled = []
@@ -225,25 +231,27 @@ class AbstractProblem(metaclass=ABCMeta):
             # but we want to sample again we set already_sampled
             # to an empty list since we need to sample again, and
             # self._have_sampled_points to False.
-            if self._have_sampled_points[location]:
+            if self._discretized_domains[d]:
                 already_sampled = []
-                self._have_sampled_points[location] = False
+                self._discretized_domains[d] = False
+            print(condition.domain)
+            print(d)

             # build samples
             samples = [
-                condition.location.sample(n=n, mode=mode, variables=variables)
+                self.domains[d].sample(n=n, mode=mode, variables=variables)
             ] + already_sampled
             pts = merge_tensors(samples)
-            self.input_pts[location] = pts
+            self.input_pts[d] = pts

             # the condition is sampled if input_pts contains all labels
-            if sorted(self.input_pts[location].labels) == sorted(
+            if sorted(self.input_pts[d].labels) == sorted(
                 self.input_variables
             ):
-                self._have_sampled_points[location] = True
-                self.input_pts[location] = self.input_pts[location].extract(
-                    sorted(self.input_variables)
-                )
+                # self._have_sampled_points[location] = True
+                # self.input_pts[location] = self.input_pts[location].extract(
+                #     sorted(self.input_variables)
+                # )
+                self._have_sampled_points[d] = True

     def add_points(self, new_points):
         """


@@ -134,8 +134,6 @@ class SupervisedSolver(SolverInterface):
         condition = self.problem.conditions[condition_name]
         pts = batch.input
         out = batch.output
-        print(out)
-        print(pts)

         if condition_name not in self.problem.conditions:
             raise RuntimeError("Something wrong happened.")


@@ -5,7 +5,7 @@ from pina import Condition, LabelTensor
 from pina.solvers import SupervisedSolver
 from pina.trainer import Trainer
 from pina.model import FeedForward
-from pina.loss.loss_interface import LpLoss
+from pina.loss import LpLoss

 class NeuralOperatorProblem(AbstractProblem):
@@ -94,11 +94,9 @@ class GraphModel(torch.nn.Module):
         return x

 def test_graph():
     solver = AutoSolver(problem = problem, model=GraphModel(2, 1), loss=LpLoss())
     trainer = Trainer(solver=solver, max_epochs=30, accelerator='cpu', batch_size=20)
     trainer.train()
-    assert False

 def test_train_cpu():
@@ -107,6 +105,7 @@ def test_train_cpu():
     trainer.train()

 # def test_train_restore():
 #     tmpdir = "tests/tmp_restore"
 #     solver = SupervisedSolver(problem=problem,