Introduce add_points method in AbstractProblem, removed unused comments in Collector class and add the test for add_points and codacy corrections
This commit is contained in:
committed by
Nicola Demo
parent
004cbc00c0
commit
f578b2ed12
@@ -6,6 +6,7 @@ from ..domain import DomainInterface
|
||||
from ..condition.domain_equation_condition import DomainEquationCondition
|
||||
from ..condition import InputPointsEquationCondition
|
||||
from copy import deepcopy
|
||||
from pina import LabelTensor
|
||||
|
||||
|
||||
class AbstractProblem(metaclass=ABCMeta):
|
||||
@@ -135,12 +136,11 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
"""
|
||||
The conditions of the problem.
|
||||
"""
|
||||
return self._conditions
|
||||
return self.conditions
|
||||
|
||||
def discretise_domain(self,
|
||||
n,
|
||||
mode="random",
|
||||
variables="all",
|
||||
domains="all"):
|
||||
"""
|
||||
Generate a set of points to span the `Location` of all the conditions of
|
||||
@@ -153,10 +153,8 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
Available modes include: random sampling, ``random``;
|
||||
latin hypercube sampling, ``latin`` or ``lh``;
|
||||
chebyshev sampling, ``chebyshev``; grid sampling ``grid``.
|
||||
:param variables: problem's variables to be sampled, defaults to 'all'.
|
||||
:type variables: str | list[str]
|
||||
:param domain: problem's domain from where to sample, defaults to 'all'.
|
||||
:type locations: str
|
||||
:param domains: problem's domain from where to sample, defaults to 'all'.
|
||||
:type domains: str | list[str]
|
||||
|
||||
:Example:
|
||||
>>> pinn.discretise_domain(n=10, mode='grid')
|
||||
@@ -174,22 +172,12 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
# check consistecy n, mode, variables, locations
|
||||
check_consistency(n, int)
|
||||
check_consistency(mode, str)
|
||||
check_consistency(variables, str)
|
||||
check_consistency(domains, (list, str))
|
||||
|
||||
# check correct sampling mode
|
||||
# if mode not in DomainInterface.available_sampling_modes:
|
||||
# raise TypeError(f"mode {mode} not valid.")
|
||||
|
||||
# check correct variables
|
||||
if variables == "all":
|
||||
variables = self.input_variables
|
||||
for variable in variables:
|
||||
if variable not in self.input_variables:
|
||||
TypeError(
|
||||
f"Wrong variables for sampling. Variables ",
|
||||
f"should be in {self.input_variables}.",
|
||||
)
|
||||
# check correct location
|
||||
if domains == "all":
|
||||
domains = self.domains.keys()
|
||||
@@ -198,14 +186,16 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
|
||||
for domain in domains:
|
||||
self.discretised_domains[domain] = (
|
||||
self.domains[domain].sample(n, mode, variables)
|
||||
self.domains[domain].sample(n, mode)
|
||||
)
|
||||
# if not isinstance(self.conditions[loc], DomainEquationCondition):
|
||||
# raise TypeError(
|
||||
# f"Wrong locations passed, locations for sampling "
|
||||
# f"should be in {[loc for loc in locations if isinstance(self.conditions[loc], DomainEquationCondition)]}.",
|
||||
# )
|
||||
|
||||
# store data
|
||||
# self.collector.store_sample_domains()
|
||||
# self.collector.store_sample_domains(n, mode, variables, domain)
|
||||
def add_points(self, new_points_dict):
|
||||
"""
|
||||
Add input points to a sampled condition
|
||||
:param new_points_dict: Dictionary of input points (condition_name:
|
||||
LabelTensor)
|
||||
:raises RuntimeError: if at least one condition is not already sampled
|
||||
"""
|
||||
for k, v in new_points_dict.items():
|
||||
self.discretised_domains[k] = LabelTensor.vstack(
|
||||
[self.discretised_domains[k], v])
|
||||
|
||||
Reference in New Issue
Block a user