Files
PINA/pina/condition/condition.py
FilippoOlivo 30f865d912 Fix bugs in 0.2 (#344)
* Fix some bugs
2025-03-19 17:46:33 +01:00

107 lines
4.2 KiB
Python

""" Condition module. """
from ..label_tensor import LabelTensor
from ..domain import DomainInterface
from ..equation.equation import Equation
from . import DomainOutputCondition, DomainEquationCondition
def dummy(a):
    """
    Placeholder callable intended for tests.

    :param a: Any value; it is ignored.
    :return: Always ``None``.
    :rtype: None
    """
    _ = a  # argument intentionally unused; the function has no effect
class Condition:
    """
    The class ``Condition`` is used to represent the constraints (physical
    equations, boundary conditions, etc.) that should be satisfied in the
    problem at hand. Condition objects are used to formulate the PINA
    :obj:`pina.problem.abstract_problem.AbstractProblem` object.

    ``Condition`` is a factory: calling it does not return a ``Condition``
    instance but a concrete condition object chosen from the keyword
    arguments. Exactly one of the following keyword combinations must be
    given (positional arguments are rejected):

    1. ``input_points`` and ``output_points``: the model is trained to
       produce the output points given the input points. The input points
       are forwarded to ``DomainOutputCondition`` as its ``domain``.
    2. ``domain`` and ``output_points``: builds a ``DomainOutputCondition``.
    3. ``domain`` and ``equation``: builds a ``DomainEquationCondition``;
       the model is trained to minimize the equation residual on samples
       of the domain.

    Example::

        >>> example_domain = ...  # a domain object, e.g. a DomainInterface
        >>> def example_dirichlet(input_, output_):
        >>>     value = 0.0
        >>>     return output_.extract(['u']) - value
        >>> example_input_pts = LabelTensor(
        >>>     torch.tensor([[0, 0, 0]]), ['x', 'y', 'z'])
        >>> example_output_pts = LabelTensor(torch.tensor([[1, 2]]), ['a', 'b'])
        >>>
        >>> Condition(
        >>>     input_points=example_input_pts,
        >>>     output_points=example_output_pts)
        >>> Condition(
        >>>     domain=example_domain,
        >>>     equation=example_dirichlet)
    """

    # Accepted keyword-argument combinations, used for error messages.
    _AVAILABLE_KWARGS = (
        ("input_points", "output_points"),
        ("domain", "output_points"),
        ("domain", "equation"),
    )

    def __new__(cls, *args, **kwargs):
        """
        Build the concrete condition matching the given keyword arguments.

        :param args: Not accepted; any positional argument raises.
        :param kwargs: One of the keyword combinations listed in the class
            docstring.
        :raises ValueError: If positional arguments are passed or the
            keyword combination is not supported.
        :return: A ``DomainOutputCondition`` or ``DomainEquationCondition``
            object built from ``kwargs``.
        """
        accepted = ", ".join(
            "(" + ", ".join(combo) + ")" for combo in cls._AVAILABLE_KWARGS
        )
        # Positional arguments used to be silently ignored; reject them so
        # that misuse is surfaced instead of dropping data.
        if args:
            raise ValueError(
                "Condition takes only keyword arguments, one of the "
                f"combinations: {accepted}."
            )

        keys = frozenset(kwargs)
        if keys == {"input_points", "output_points"}:
            # Raw input points are treated as the condition's domain.
            return DomainOutputCondition(
                domain=kwargs["input_points"],
                output_points=kwargs["output_points"],
            )
        if keys == {"domain", "output_points"}:
            return DomainOutputCondition(**kwargs)
        if keys == {"domain", "equation"}:
            return DomainEquationCondition(**kwargs)
        raise ValueError(
            f"Invalid keyword arguments {sorted(keys)}. "
            f"Accepted combinations are: {accepted}."
        )