# Files
# PINA/pina/problem/abstract_problem.py
# 2025-03-19 17:46:35 +01:00
#
# 205 lines
# 6.9 KiB
# Python

""" Module for AbstractProblem class """
from abc import ABCMeta, abstractmethod
from ..utils import check_consistency
from ..domain import DomainInterface
from ..condition.domain_equation_condition import DomainEquationCondition
from ..collector import Collector
from copy import deepcopy
class AbstractProblem(metaclass=ABCMeta):
    """
    The abstract `AbstractProblem` class. Every class defining a PINA problem
    should inherit from this class.

    In the definition of a PINA problem, the fundamental elements are:
    the output variables, the condition(s), and the domain(s) where the
    conditions are applied.

    Concrete subclasses must provide ``output_variables`` and ``conditions``.
    The methods below also read ``self.domains`` (a mapping from a domain
    name to an object exposing ``sample(n, mode, variables)``) and, when
    present, ``spatial_variables``, ``temporal_variable`` and ``parameters``.
    """

    def __init__(self):
        """
        Initialise the state shared by all problems and link every condition
        back to the problem it belongs to.
        """
        # maps a domain name to the points returned by ``domain.sample``
        self.discretised_domains = {}

        # create hook conditions <-> problems
        for condition_name in self.conditions:
            self.conditions[condition_name].problem = self

        # NOTE: storing the fixed data points in a collector is currently
        # disabled (the corresponding calls are commented out upstream).
        self._batching_dimension = 0

    @property
    def batching_dimension(self):
        """
        The batching dimension of the problem (``0`` by default).

        :return: the currently configured batching dimension
        """
        return self._batching_dimension

    @batching_dimension.setter
    def batching_dimension(self, value):
        """
        Set the batching dimension of the problem.

        :param value: the new batching dimension; stored without validation
        """
        self._batching_dimension = value

    # TODO this should be erased when dataloading will interface collector,
    # kept only for back compatibility
    @property
    def input_pts(self):
        """
        Collect the ``'input_points'`` entry of every data collection held by
        the collector.

        :return: mapping from collection name to its ``'input_points'``
        :rtype: dict
        """
        # NOTE(review): ``self.collector`` is not set anywhere in this class
        # (the ``collector`` property is commented out upstream), so this
        # raises AttributeError unless a subclass or a caller provides it.
        return {
            name: collection['input_points']
            for name, collection in self.collector.data_collections.items()
            if 'input_points' in collection.keys()
        }

    def __deepcopy__(self, memo):
        """
        Implements deepcopy for the
        :class:`~pina.problem.abstract_problem.AbstractProblem` class.

        :param dict memo: Memory dictionary, to avoid excess copy
        :return: The deep copy of the
            :class:`~pina.problem.abstract_problem.AbstractProblem` class
        :rtype: AbstractProblem
        """
        cls = self.__class__
        # bypass __init__: the copy is rebuilt attribute by attribute
        result = cls.__new__(cls)
        # register the copy before recursing so that cyclic references
        # (e.g. condition.problem -> self) resolve to ``result``
        memo[id(self)] = result
        for name, value in self.__dict__.items():
            setattr(result, name, deepcopy(value, memo))
        return result

    @property
    def are_all_domains_discretised(self):
        """
        Check if all the domains are discretised.

        :return: True if all the domains are discretised, False otherwise
        :rtype: bool
        """
        return all(
            domain in self.discretised_domains for domain in self.domains
        )

    @property
    def input_variables(self):
        """
        The input variables of the AbstractProblem, whose type depends on the
        type of domain (spatial, temporal, and parameter).

        The result is the concatenation, in this order, of
        ``spatial_variables``, ``temporal_variable`` and ``parameters`` for
        those attributes that the problem defines.

        :return: the input variables of self
        :rtype: list
        """
        variables = []
        for attribute in (
            "spatial_variables",
            "temporal_variable",
            "parameters",
        ):
            if hasattr(self, attribute):
                variables += getattr(self, attribute)
        return variables

    @input_variables.setter
    def input_variables(self, variables):
        """
        Reject any attempt to assign the input variables directly.

        :raises RuntimeError: always, the property is derived from the
            spatial/temporal/parameter attributes of the problem
        """
        raise RuntimeError(
            "input_variables cannot be set directly; it is derived from "
            "spatial_variables, temporal_variable and parameters."
        )

    @property
    @abstractmethod
    def output_variables(self):
        """
        The output variables of the problem.
        """

    @property
    @abstractmethod
    def conditions(self):
        """
        The conditions of the problem.
        """
        return self._conditions

    def _resolve_variables(self, variables):
        """
        Turn the ``variables`` argument of :meth:`discretise_domain` into the
        list of variable names to sample, validating every name.

        :param variables: ``'all'``, a single variable name, or a list of
            variable names
        :type variables: str | list[str]
        :return: the variables to be sampled
        :rtype: list[str]
        :raises TypeError: if a requested variable is not one of the input
            variables of the problem
        """
        input_variables = self.input_variables
        if variables == "all":
            return input_variables
        # a bare name must be treated as one variable, not as a sequence of
        # characters (e.g. 'mu' is a single variable, not 'm' and 'u')
        if isinstance(variables, str):
            variables = [variables]
        unknown = [
            variable for variable in variables
            if variable not in input_variables
        ]
        if unknown:
            raise TypeError(
                f"Wrong variables for sampling: {unknown}. Variables "
                f"should be in {input_variables}."
            )
        return variables

    def _resolve_domains(self, domains):
        """
        Turn the ``domains`` argument of :meth:`discretise_domain` into a list
        of domain names.

        :param domains: ``'all'``, a single domain name, or a list of names
        :type domains: str | list[str]
        :return: the names of the domains to be sampled
        :rtype: list
        """
        if domains == "all":
            return list(self.domains)
        if isinstance(domains, list):
            return domains
        return [domains]

    def discretise_domain(self,
                          n,
                          mode="random",
                          variables="all",
                          domains="all"):
        """
        Sample points from the domains of the problem and store them in
        ``self.discretised_domains`` under the corresponding domain name.
        Previously stored points for the same domain are overwritten.

        :param n: Number of points to sample; it is forwarded as is to the
            ``sample`` method of each selected domain.
        :type n: int
        :param mode: Mode for sampling, defaults to ``random``.
            Available modes include: random sampling, ``random``;
            latin hypercube sampling, ``latin`` or ``lh``;
            chebyshev sampling, ``chebyshev``; grid sampling ``grid``.
        :type mode: str
        :param variables: problem's variables to be sampled, defaults to
            'all'.
        :type variables: str | list[str]
        :param domains: problem's domain(s) from where to sample, defaults to
            'all'.
        :type domains: str | list[str]
        :raises TypeError: if one of ``variables`` is not an input variable
            of the problem
        :raises KeyError: if one of ``domains`` is not a key of
            ``self.domains``

        :Example:
            >>> pinn.discretise_domain(n=10, mode='grid')
            >>> pinn.discretise_domain(n=10, mode='grid', domains=['bound1'])
            >>> pinn.discretise_domain(n=10, mode='grid', variables=['x'])

        .. warning::
            ``random`` is currently the only implemented ``mode`` for all
            geometries, i.e. ``EllipsoidDomain``, ``CartesianDomain``,
            ``SimplexDomain`` and the geometries compositions ``Union``,
            ``Difference``, ``Exclusion``, ``Intersection``. The modes
            ``latin`` or ``lh``, ``chebyshev``, ``grid`` are only implemented
            for ``CartesianDomain``.
        """
        # check consistency of n, mode, variables, domains
        check_consistency(n, int)
        check_consistency(mode, str)
        check_consistency(variables, str)
        check_consistency(domains, (list, str))

        # NOTE: validation of ``mode`` against the available sampling modes
        # is currently disabled upstream; an unsupported mode is left to the
        # domain's own ``sample`` method to handle.
        variables = self._resolve_variables(variables)

        for domain in self._resolve_domains(domains):
            self.discretised_domains[domain] = (
                self.domains[domain].sample(n, mode, variables)
            )
        # NOTE: storing the sampled domains in a collector is currently
        # disabled (the corresponding calls are commented out upstream).