clean logic, fix problems for tutorial1
This commit is contained in:
@@ -20,8 +20,8 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
|
||||
def __init__(self):
|
||||
|
||||
self.discretised_domains = {}
|
||||
# create collector to manage problem data
|
||||
self._collector = Collector(self)
|
||||
|
||||
# create hook conditions <-> problems
|
||||
for condition_name in self.conditions:
|
||||
@@ -32,12 +32,12 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
# sampling locations). To check that all data points are ready for
|
||||
# training all type self.collector.full, which returns true if all
|
||||
# points are ready.
|
||||
self.collector.store_fixed_data()
|
||||
# self.collector.store_fixed_data()
|
||||
self._batching_dimension = 0
|
||||
|
||||
@property
|
||||
def collector(self):
|
||||
return self._collector
|
||||
# @property
|
||||
# def collector(self):
|
||||
# return self._collector
|
||||
|
||||
@property
|
||||
def batching_dimension(self):
|
||||
@@ -74,6 +74,21 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
setattr(result, k, deepcopy(v, memo))
|
||||
return result
|
||||
|
||||
@property
|
||||
def are_all_domains_discretised(self):
|
||||
"""
|
||||
Check if all the domains are discretised.
|
||||
|
||||
:return: True if all the domains are discretised, False otherwise
|
||||
:rtype: bool
|
||||
"""
|
||||
return all(
|
||||
[
|
||||
domain in self.discretised_domains
|
||||
for domain in self.domains.keys()
|
||||
]
|
||||
)
|
||||
|
||||
@property
|
||||
def input_variables(self):
|
||||
"""
|
||||
@@ -120,7 +135,7 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
n,
|
||||
mode="random",
|
||||
variables="all",
|
||||
locations="all"):
|
||||
domains="all"):
|
||||
"""
|
||||
Generate a set of points to span the `Location` of all the conditions of
|
||||
the problem.
|
||||
@@ -134,12 +149,12 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
chebyshev sampling, ``chebyshev``; grid sampling ``grid``.
|
||||
:param variables: problem's variables to be sampled, defaults to 'all'.
|
||||
:type variables: str | list[str]
|
||||
:param locations: problem's locations from where to sample, defaults to 'all'.
|
||||
:param domain: problem's domain from where to sample, defaults to 'all'.
|
||||
:type locations: str
|
||||
|
||||
:Example:
|
||||
>>> pinn.discretise_domain(n=10, mode='grid')
|
||||
>>> pinn.discretise_domain(n=10, mode='grid', location=['bound1'])
|
||||
>>> pinn.discretise_domain(n=10, mode='grid', domain=['bound1'])
|
||||
>>> pinn.discretise_domain(n=10, mode='grid', variables=['x'])
|
||||
|
||||
.. warning::
|
||||
@@ -154,11 +169,11 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
check_consistency(n, int)
|
||||
check_consistency(mode, str)
|
||||
check_consistency(variables, str)
|
||||
check_consistency(locations, str)
|
||||
check_consistency(domains, (list, str))
|
||||
|
||||
# check correct sampling mode
|
||||
if mode not in DomainInterface.available_sampling_modes:
|
||||
raise TypeError(f"mode {mode} not valid.")
|
||||
# if mode not in DomainInterface.available_sampling_modes:
|
||||
# raise TypeError(f"mode {mode} not valid.")
|
||||
|
||||
# check correct variables
|
||||
if variables == "all":
|
||||
@@ -170,23 +185,21 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
f"should be in {self.input_variables}.",
|
||||
)
|
||||
# check correct location
|
||||
if locations == "all":
|
||||
locations = [
|
||||
name for name in self.conditions.keys()
|
||||
if isinstance(self.conditions[name], DomainEquationCondition)
|
||||
]
|
||||
else:
|
||||
if not isinstance(locations, (list)):
|
||||
locations = [locations]
|
||||
for loc in locations:
|
||||
if not isinstance(self.conditions[loc], DomainEquationCondition):
|
||||
raise TypeError(
|
||||
f"Wrong locations passed, locations for sampling "
|
||||
f"should be in {[loc for loc in locations if isinstance(self.conditions[loc], DomainEquationCondition)]}.",
|
||||
)
|
||||
if domains == "all":
|
||||
domains = self.domains.keys()
|
||||
elif not isinstance(domains, (list)):
|
||||
domains = [domains]
|
||||
|
||||
for domain in domains:
|
||||
self.discretised_domains[domain] = (
|
||||
self.domains[domain].sample(n, mode, variables)
|
||||
)
|
||||
# if not isinstance(self.conditions[loc], DomainEquationCondition):
|
||||
# raise TypeError(
|
||||
# f"Wrong locations passed, locations for sampling "
|
||||
# f"should be in {[loc for loc in locations if isinstance(self.conditions[loc], DomainEquationCondition)]}.",
|
||||
# )
|
||||
|
||||
# store data
|
||||
self.collector.store_sample_domains(n, mode, variables, locations)
|
||||
|
||||
def add_points(self, new_points_dict):
|
||||
self.collector.add_points(new_points_dict)
|
||||
# self.collector.store_sample_domains()
|
||||
# self.collector.store_sample_domains(n, mode, variables, domain)
|
||||
Reference in New Issue
Block a user