minor changes/ trainer update

This commit is contained in:
Dario Coscia
2024-10-10 19:24:46 +02:00
committed by Nicola Demo
parent 7528f6ef74
commit b9753c34b2
8 changed files with 69 additions and 46 deletions

View File

@@ -20,7 +20,7 @@ class AbstractProblem(metaclass=ABCMeta):
def __init__(self):
# create collector to manage problem data
self.collector = Collector(self)
self._collector = Collector(self)
# create hook conditions <-> problems
for condition_name in self.conditions:
@@ -33,7 +33,12 @@ class AbstractProblem(metaclass=ABCMeta):
# points are ready.
self.collector.store_fixed_data()
@property
def collector(self):
return self._collector
# TODO this should be erase when dataloading will interface collector,
# kept only for back compatibility
@property
def input_pts(self):
to_return = {}
@@ -41,10 +46,6 @@ class AbstractProblem(metaclass=ABCMeta):
if 'input_points' in v.keys():
to_return[k] = v['input_points']
return to_return
@property
def _have_sampled_points(self):
return self.collector.is_conditions_ready
def __deepcopy__(self, memo):
"""
@@ -160,7 +161,9 @@ class AbstractProblem(metaclass=ABCMeta):
# check correct location
if locations == "all":
locations = [name for name in self.conditions.keys()]
locations = [name for name in self.conditions.keys()
if isinstance(self.conditions[name],
DomainEquationCondition)]
else:
if not isinstance(locations, (list)):
locations = [locations]
@@ -168,7 +171,7 @@ class AbstractProblem(metaclass=ABCMeta):
if not isinstance(self.conditions[loc], DomainEquationCondition):
raise TypeError(
f"Wrong locations passed, locations for sampling "
f"should be in {[loc for loc in locations if not isinstance(self.conditions[loc], DomainEquationCondition)]}.",
f"should be in {[loc for loc in locations if isinstance(self.conditions[loc], DomainEquationCondition)]}.",
)
# store data