clean logic, fix problems for tutorial1

This commit is contained in:
Nicola Demo
2025-02-06 14:28:17 +01:00
parent 7702427e8d
commit effd1e83bb
5 changed files with 105 additions and 76 deletions

View File

@@ -20,8 +20,8 @@ class AbstractProblem(metaclass=ABCMeta):
def __init__(self):
self.discretised_domains = {}
# create collector to manage problem data
self._collector = Collector(self)
# create hook conditions <-> problems
for condition_name in self.conditions:
@@ -32,12 +32,12 @@ class AbstractProblem(metaclass=ABCMeta):
# sampling locations). To check that all data points are ready for
# training all type self.collector.full, which returns true if all
# points are ready.
self.collector.store_fixed_data()
# self.collector.store_fixed_data()
self._batching_dimension = 0
@property
def collector(self):
return self._collector
# @property
# def collector(self):
# return self._collector
@property
def batching_dimension(self):
@@ -74,6 +74,21 @@ class AbstractProblem(metaclass=ABCMeta):
setattr(result, k, deepcopy(v, memo))
return result
@property
def are_all_domains_discretised(self):
"""
Check if all the domains are discretised.
:return: True if all the domains are discretised, False otherwise
:rtype: bool
"""
return all(
[
domain in self.discretised_domains
for domain in self.domains.keys()
]
)
@property
def input_variables(self):
"""
@@ -120,7 +135,7 @@ class AbstractProblem(metaclass=ABCMeta):
n,
mode="random",
variables="all",
locations="all"):
domains="all"):
"""
Generate a set of points to span the `Location` of all the conditions of
the problem.
@@ -134,12 +149,12 @@ class AbstractProblem(metaclass=ABCMeta):
chebyshev sampling, ``chebyshev``; grid sampling ``grid``.
:param variables: problem's variables to be sampled, defaults to 'all'.
:type variables: str | list[str]
:param locations: problem's locations from where to sample, defaults to 'all'.
:param domain: problem's domain from where to sample, defaults to 'all'.
:type locations: str
:Example:
>>> pinn.discretise_domain(n=10, mode='grid')
>>> pinn.discretise_domain(n=10, mode='grid', location=['bound1'])
>>> pinn.discretise_domain(n=10, mode='grid', domain=['bound1'])
>>> pinn.discretise_domain(n=10, mode='grid', variables=['x'])
.. warning::
@@ -154,11 +169,11 @@ class AbstractProblem(metaclass=ABCMeta):
check_consistency(n, int)
check_consistency(mode, str)
check_consistency(variables, str)
check_consistency(locations, str)
check_consistency(domains, (list, str))
# check correct sampling mode
if mode not in DomainInterface.available_sampling_modes:
raise TypeError(f"mode {mode} not valid.")
# if mode not in DomainInterface.available_sampling_modes:
# raise TypeError(f"mode {mode} not valid.")
# check correct variables
if variables == "all":
@@ -170,23 +185,21 @@ class AbstractProblem(metaclass=ABCMeta):
f"should be in {self.input_variables}.",
)
# check correct location
if locations == "all":
locations = [
name for name in self.conditions.keys()
if isinstance(self.conditions[name], DomainEquationCondition)
]
else:
if not isinstance(locations, (list)):
locations = [locations]
for loc in locations:
if not isinstance(self.conditions[loc], DomainEquationCondition):
raise TypeError(
f"Wrong locations passed, locations for sampling "
f"should be in {[loc for loc in locations if isinstance(self.conditions[loc], DomainEquationCondition)]}.",
)
if domains == "all":
domains = self.domains.keys()
elif not isinstance(domains, (list)):
domains = [domains]
for domain in domains:
self.discretised_domains[domain] = (
self.domains[domain].sample(n, mode, variables)
)
# if not isinstance(self.conditions[loc], DomainEquationCondition):
# raise TypeError(
# f"Wrong locations passed, locations for sampling "
# f"should be in {[loc for loc in locations if isinstance(self.conditions[loc], DomainEquationCondition)]}.",
# )
# store data
self.collector.store_sample_domains(n, mode, variables, locations)
def add_points(self, new_points_dict):
self.collector.add_points(new_points_dict)
# self.collector.store_sample_domains()
# self.collector.store_sample_domains(n, mode, variables, domain)