""" Module for AbstractProblem class """
from abc import ABCMeta, abstractmethod
from ..utils import check_consistency
from ..domain import DomainInterface
from ..condition.domain_equation_condition import DomainEquationCondition
from ..condition import InputPointsEquationCondition
from copy import deepcopy
from pina import LabelTensor


class AbstractProblem(metaclass=ABCMeta):
    """
    The abstract `AbstractProblem` class. All the classes defining a PINA
    Problem should inherit from this class.

    In the definition of a PINA problem, the fundamental elements are:
    the output variables, the condition(s), and the domain(s) where the
    conditions are applied.
    """

    def __init__(self):
        """
        Initialise the problem.

        Creates the (initially empty) ``discretised_domains`` dictionary,
        attaches this problem to every condition through its ``problem``
        attribute, sets the batching dimension to ``0`` and collects in
        ``self.domains`` every domain object that was passed directly to a
        condition.
        """
        self.discretised_domains = {}
        # create collector to manage problem data

        # create hook conditions <-> problems
        for condition in self.conditions.values():
            condition.problem = self

        self._batching_dimension = 0

        # Store in domains dict all the domains object directly passed to
        # ConditionInterface. Done for back compatibility with PINA <0.2
        if not hasattr(self, "domains"):
            self.domains = {}
        for cond_name, cond in self.conditions.items():
            if not isinstance(
                cond, (DomainEquationCondition, InputPointsEquationCondition)
            ):
                continue
            # ``domain`` is treated as optional elsewhere in this class
            # (see ``input_pts``), so a condition without it is skipped
            # instead of raising AttributeError.
            domain = getattr(cond, "domain", None)
            if isinstance(domain, DomainInterface):
                # Register the domain under the condition name and let the
                # condition refer to it by that name from now on.
                self.domains[cond_name] = domain
                cond.domain = cond_name

    # @property
    # def collector(self):
    #     return self._collector

    @property
    def batching_dimension(self):
        """
        The batching dimension stored on the problem (``0`` after
        initialisation).

        :return: the current batching dimension
        """
        return self._batching_dimension

    @batching_dimension.setter
    def batching_dimension(self, value):
        """
        Set the batching dimension. No validation is performed on ``value``.

        :param value: the new batching dimension
        """
        self._batching_dimension = value

    # TODO this should be erase when dataloading will interface collector,
    # kept only for back compatibility
    @property
    def input_pts(self):
        """
        The input points of each condition, keyed by condition name.

        For a condition exposing ``input_points`` those points are used;
        otherwise, for a condition exposing ``domain``, the points stored in
        ``discretised_domains`` under that domain name are used. Conditions
        with neither attribute are omitted.

        :return: dictionary mapping condition names to their input points
        :rtype: dict
        :raises KeyError: if a condition refers to a domain that is not in
            ``discretised_domains``
        """
        to_return = {}
        for cond_name, cond in self.conditions.items():
            if hasattr(cond, "input_points"):
                to_return[cond_name] = cond.input_points
            elif hasattr(cond, "domain"):
                to_return[cond_name] = self.discretised_domains[cond.domain]
        return to_return

    def __deepcopy__(self, memo):
        """
        Implements deepcopy for the
        :class:`~pina.problem.abstract_problem.AbstractProblem` class.

        :param dict memo: Memory dictionary, to avoid excess copy
        :return: The deep copy of the
            :class:`~pina.problem.abstract_problem.AbstractProblem` class
        :rtype: AbstractProblem
        """
        cls = self.__class__
        # Bypass __init__: the copy receives its state attribute by attribute.
        result = cls.__new__(cls)
        # Register the copy before recursing so that references back to this
        # problem resolve to the new object.
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            setattr(result, k, deepcopy(v, memo))
        return result

    @property
    def are_all_domains_discretised(self):
        """
        Check if all the domains are discretised.

        :return: True if all the domains are discretised, False otherwise
        :rtype: bool
        """
        return all(
            domain in self.discretised_domains for domain in self.domains
        )

    @property
    def input_variables(self):
        """
        The input variables of the AbstractProblem, whose type depends on the
        type of domain (spatial, temporal, and parameter).

        The result concatenates, in this order, ``spatial_variables``,
        ``temporal_variable`` and ``parameters``, for those attributes that
        are defined on the problem.

        :return: the input variables of self
        :rtype: list
        """
        variables = []
        if hasattr(self, "spatial_variables"):
            variables += self.spatial_variables
        if hasattr(self, "temporal_variable"):
            variables += self.temporal_variable
        if hasattr(self, "parameters"):
            variables += self.parameters
        return variables

    @input_variables.setter
    def input_variables(self, variables):
        """
        Reject any attempt to assign the input variables, which are always
        computed from the problem attributes.

        :raises RuntimeError: always
        """
        raise RuntimeError(
            "input_variables is read-only: it is derived from the "
            "spatial_variables, temporal_variable and parameters attributes."
        )

    @property
    @abstractmethod
    def output_variables(self):
        """
        The output variables of the problem.
        """

    @property
    @abstractmethod
    def conditions(self):
        """
        The conditions of the problem.
        """
        # Concrete problems must override this attribute; the lookup below
        # then resolves to the override.
        return self.conditions

    def discretise_domain(self, n, mode="random", domains="all"):
        """
        Generate a set of points to span the `Location` of all the conditions
        of the problem.

        The sampled points are stored in ``discretised_domains`` under the
        name of the domain, replacing any points previously stored there.

        :param n: Number of points to sample, see Note below
            for reference.
        :type n: int
        :param mode: Mode for sampling, defaults to ``random``.
            Available modes include: random sampling, ``random``;
            latin hypercube sampling, ``latin`` or ``lh``;
            chebyshev sampling, ``chebyshev``; grid sampling ``grid``.
        :type mode: str
        :param domains: problem's domain from where to sample, defaults to
            'all'.
        :type domains: str | list[str]
        :raises KeyError: if a requested domain name is not in
            ``self.domains``

        :Example:
            >>> problem.discretise_domain(n=10, mode='grid')
            >>> problem.discretise_domain(n=10, mode='grid', domains=['bound1'])
            >>> problem.discretise_domain(n=10, mode='grid', domains='bound1')

        .. warning::
            ``random`` is currently the only implemented ``mode`` for all
            geometries, i.e. ``EllipsoidDomain``, ``CartesianDomain``,
            ``SimplexDomain`` and the geometries compositions ``Union``,
            ``Difference``, ``Exclusion``, ``Intersection``. The
            modes ``latin`` or ``lh``, ``chebyshev``, ``grid`` are only
            implemented for ``CartesianDomain``.
        """
        # check consistency of n, mode and domains
        check_consistency(n, int)
        check_consistency(mode, str)
        check_consistency(domains, (list, str))

        # check correct sampling mode
        # if mode not in DomainInterface.available_sampling_modes:
        #     raise TypeError(f"mode {mode} not valid.")

        # Normalise ``domains`` to a list of domain names.
        if domains == "all":
            domains = list(self.domains)
        elif not isinstance(domains, list):
            domains = [domains]

        for domain in domains:
            self.discretised_domains[domain] = self.domains[domain].sample(
                n, mode
            )

    def add_points(self, new_points_dict):
        """
        Add input points to a sampled condition.

        Every key is checked before any point is added, so a failure leaves
        ``discretised_domains`` untouched.

        :param new_points_dict: Dictionary of input points (condition_name:
            LabelTensor)
        :type new_points_dict: dict
        :raises RuntimeError: if at least one condition is not already sampled
        """
        not_sampled = [
            name
            for name in new_points_dict
            if name not in self.discretised_domains
        ]
        if not_sampled:
            raise RuntimeError(
                "Cannot add points to conditions that are not already "
                f"sampled: {not_sampled}."
            )

        for name, points in new_points_dict.items():
            self.discretised_domains[name] = LabelTensor.vstack(
                [self.discretised_domains[name], points]
            )