Fix bugs in 0.2 (#344)

* Fix some bugs
This commit is contained in:
FilippoOlivo
2024-09-12 18:12:59 +02:00
committed by Nicola Demo
parent f0d68b34c7
commit 30f865d912
11 changed files with 112 additions and 55 deletions

View File

@@ -84,14 +84,15 @@ class Condition:
return DomainEquationCondition(**kwargs)
else:
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
# TODO: remove, not used anymore
'''
if (
sorted(kwargs.keys()) != sorted(["input_points", "output_points"])
and sorted(kwargs.keys()) != sorted(["location", "equation"])
and sorted(kwargs.keys()) != sorted(["input_points", "equation"])
):
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
# TODO: remove, not used anymore
if not self._dictvalue_isinstance(kwargs, "input_points", LabelTensor):
raise TypeError("`input_points` must be a torch.Tensor.")
if not self._dictvalue_isinstance(kwargs, "output_points", LabelTensor):
@@ -103,3 +104,4 @@ class Condition:
for key, value in kwargs.items():
setattr(self, key, value)
'''

View File

@@ -15,4 +15,7 @@ class ConditionInterface(metaclass=ABCMeta):
:param model: The model to evaluate the condition.
:return: The residual of the condition.
"""
pass
pass
def set_problem(self, problem):
self._problem = problem

View File

@@ -15,6 +15,12 @@ class DomainEquationCondition(ConditionInterface):
self.domain = domain
self.equation = equation
def residual(self, model):
"""
Compute the residual of the condition.
"""
self.batch_residual(model, self.domain, self.equation)
@staticmethod
def batch_residual(model, input_pts, equation):
"""
@@ -22,7 +28,7 @@ class DomainEquationCondition(ConditionInterface):
output points are provided as arguments.
:param torch.nn.Module model: The model to evaluate the condition.
:param torch.Tensor input_points: The input points.
:param torch.Tensor output_points: The output points.
:param torch.Tensor input_pts: The input points.
:param torch.Tensor equation: The output points.
"""
return equation.residual(model(input_pts))
return equation.residual(input_pts, model(input_pts))

View File

@@ -40,4 +40,5 @@ class DomainOutputCondition(ConditionInterface):
:param torch.Tensor input_points: The input points.
:param torch.Tensor output_points: The output points.
"""
return output_points - model(input_points)

View File

@@ -1,4 +1,5 @@
import torch
import torch
from .domain_interface import DomainInterface
from ..label_tensor import LabelTensor

View File

@@ -5,7 +5,6 @@ import torch
from torch import Tensor
# class LabelTensor(torch.Tensor):
# """Torch tensor with a label for any column."""
@@ -307,13 +306,13 @@ from torch import Tensor
# s = "no labels\n"
# s += super().__str__()
# return s
def issubset(a, b):
"""
Check if a is a subset of b.
"""
return set(a).issubset(set(b))
class LabelTensor(torch.Tensor):
"""Torch tensor with a label for any column."""
@@ -403,6 +402,10 @@ class LabelTensor(torch.Tensor):
return LabelTensor(new_tensor, label_to_extract)
def __str__(self):
"""
returns a string with the representation of the class
"""
s = ''
for key, value in self.labels.items():
s += f"{key}: {value}\n"
@@ -431,4 +434,32 @@ class LabelTensor(torch.Tensor):
@property
def dtype(self):
return super().dtype
return super().dtype
def to(self, *args, **kwargs):
"""
Performs Tensor dtype and/or device conversion. For more details, see
:meth:`torch.Tensor.to`.
"""
tmp = super().to(*args, **kwargs)
new = self.__class__.clone(self)
new.data = tmp.data
return new
def clone(self, *args, **kwargs):
"""
Clone the LabelTensor. For more details, see
:meth:`torch.Tensor.clone`.
:return: A copy of the tensor.
:rtype: LabelTensor
"""
# # used before merging
# try:
# out = LabelTensor(super().clone(*args, **kwargs), self.labels)
# except:
# out = super().clone(*args, **kwargs)
out = LabelTensor(super().clone(*args, **kwargs), self.labels)
return out

View File

@@ -20,7 +20,6 @@ class AbstractProblem(metaclass=ABCMeta):
def __init__(self):
self._discretized_domains = {}
for name, domain in self.domains.items():
@@ -28,18 +27,19 @@ class AbstractProblem(metaclass=ABCMeta):
self._discretized_domains[name] = domain
for condition_name in self.conditions:
self.conditions[condition_name]._problem = self
self.conditions[condition_name].set_problem(self)
# # variable storing all points
# self.input_pts = {}
self.input_pts = {}
# # varible to check if sampling is done. If no location
# # element is presented in Condition this variable is set to true
# self._have_sampled_points = {}
# for condition_name in self.conditions:
# self._have_sampled_points[condition_name] = False
for condition_name in self.conditions:
self._discretized_domains[condition_name] = False
# # put in self.input_pts all the points that we don't need to sample
# self._span_condition_points()
self._span_condition_points()
def __deepcopy__(self, memo):
"""
@@ -125,7 +125,7 @@ class AbstractProblem(metaclass=ABCMeta):
if hasattr(condition, "input_points"):
samples = condition.input_points
self.input_pts[condition_name] = samples
self._have_sampled_points[condition_name] = True
self._discretized_domains[condition_name] = True
if hasattr(self, "unknown_parameter_domain"):
# initialize the unknown parameters of the inverse problem given
# the domain the user gives
@@ -141,7 +141,7 @@ class AbstractProblem(metaclass=ABCMeta):
)
def discretise_domain(
self, n, mode="random", variables="all", locations="all"
self, n, mode="random", variables="all", domains="all"
):
"""
Generate a set of points to span the `Location` of all the conditions of
@@ -192,31 +192,37 @@ class AbstractProblem(metaclass=ABCMeta):
f"should be in {self.input_variables}.",
)
# check consistency location
locations_to_sample = [
condition
for condition in self.conditions
if hasattr(self.conditions[condition], "location")
]
if locations == "all":
# only locations that can be sampled
locations = locations_to_sample
else:
check_consistency(locations, str)
# # check consistency location # TODO: check if this is needed (from 0.1)
# locations_to_sample = [
# condition
# for condition in self.conditions
# if hasattr(self.conditions[condition], "location")
# ]
# if locations == "all":
# # only locations that can be sampled
# locations = locations_to_sample
# else:
# check_consistency(locations, str)
if sorted(locations) != sorted(locations_to_sample):
# if sorted(locations) != sorted(locations_to_sample):
if domains == "all":
domains = [condition for condition in self.conditions]
else:
check_consistency(domains, str)
print(domains)
if sorted(domains) != sorted(self.conditions):
TypeError(
f"Wrong locations for sampling. Location ",
f"should be in {locations_to_sample}.",
)
# sampling
for location in locations:
condition = self.conditions[location]
for d in domains:
condition = self.conditions[d]
# we try to check if we have already sampled
try:
already_sampled = [self.input_pts[location]]
already_sampled = [self.input_pts[d]]
# if we have not sampled, a key error is thrown
except KeyError:
already_sampled = []
@@ -225,25 +231,27 @@ class AbstractProblem(metaclass=ABCMeta):
# but we want to sample again we set already_sampled
# to an empty list since we need to sample again, and
# self._have_sampled_points to False.
if self._have_sampled_points[location]:
if self._discretized_domains[d]:
already_sampled = []
self._have_sampled_points[location] = False
self._discretized_domains[d] = False
print(condition.domain)
print(d)
# build samples
samples = [
condition.location.sample(n=n, mode=mode, variables=variables)
self.domains[d].sample(n=n, mode=mode, variables=variables)
] + already_sampled
pts = merge_tensors(samples)
self.input_pts[location] = pts
self.input_pts[d] = pts
# the condition is sampled if input_pts contains all labels
if sorted(self.input_pts[location].labels) == sorted(
if sorted(self.input_pts[d].labels) == sorted(
self.input_variables
):
self._have_sampled_points[location] = True
self.input_pts[location] = self.input_pts[location].extract(
sorted(self.input_variables)
)
# self._have_sampled_points[location] = True
# self.input_pts[location] = self.input_pts[location].extract(
# sorted(self.input_variables)
# )
self._have_sampled_points[d] = True
def add_points(self, new_points):
"""

View File

@@ -134,8 +134,6 @@ class SupervisedSolver(SolverInterface):
condition = self.problem.conditions[condition_name]
pts = batch.input
out = batch.output
print(out)
print(pts)
if condition_name not in self.problem.conditions:
raise RuntimeError("Something wrong happened.")