committed by
Nicola Demo
parent
f0d68b34c7
commit
30f865d912
@@ -20,7 +20,6 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
|
||||
def __init__(self):
|
||||
|
||||
|
||||
self._discretized_domains = {}
|
||||
|
||||
for name, domain in self.domains.items():
|
||||
@@ -28,18 +27,19 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
self._discretized_domains[name] = domain
|
||||
|
||||
for condition_name in self.conditions:
|
||||
self.conditions[condition_name]._problem = self
|
||||
self.conditions[condition_name].set_problem(self)
|
||||
|
||||
# # variable storing all points
|
||||
# self.input_pts = {}
|
||||
self.input_pts = {}
|
||||
|
||||
# # varible to check if sampling is done. If no location
|
||||
# # element is presented in Condition this variable is set to true
|
||||
# self._have_sampled_points = {}
|
||||
# for condition_name in self.conditions:
|
||||
# self._have_sampled_points[condition_name] = False
|
||||
for condition_name in self.conditions:
|
||||
self._discretized_domains[condition_name] = False
|
||||
|
||||
# # put in self.input_pts all the points that we don't need to sample
|
||||
# self._span_condition_points()
|
||||
self._span_condition_points()
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
"""
|
||||
@@ -125,7 +125,7 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
if hasattr(condition, "input_points"):
|
||||
samples = condition.input_points
|
||||
self.input_pts[condition_name] = samples
|
||||
self._have_sampled_points[condition_name] = True
|
||||
self._discretized_domains[condition_name] = True
|
||||
if hasattr(self, "unknown_parameter_domain"):
|
||||
# initialize the unknown parameters of the inverse problem given
|
||||
# the domain the user gives
|
||||
@@ -141,7 +141,7 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
)
|
||||
|
||||
def discretise_domain(
|
||||
self, n, mode="random", variables="all", locations="all"
|
||||
self, n, mode="random", variables="all", domains="all"
|
||||
):
|
||||
"""
|
||||
Generate a set of points to span the `Location` of all the conditions of
|
||||
@@ -192,31 +192,37 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
f"should be in {self.input_variables}.",
|
||||
)
|
||||
|
||||
# check consistency location
|
||||
locations_to_sample = [
|
||||
condition
|
||||
for condition in self.conditions
|
||||
if hasattr(self.conditions[condition], "location")
|
||||
]
|
||||
if locations == "all":
|
||||
# only locations that can be sampled
|
||||
locations = locations_to_sample
|
||||
else:
|
||||
check_consistency(locations, str)
|
||||
# # check consistency location # TODO: check if this is needed (from 0.1)
|
||||
# locations_to_sample = [
|
||||
# condition
|
||||
# for condition in self.conditions
|
||||
# if hasattr(self.conditions[condition], "location")
|
||||
# ]
|
||||
# if locations == "all":
|
||||
# # only locations that can be sampled
|
||||
# locations = locations_to_sample
|
||||
# else:
|
||||
# check_consistency(locations, str)
|
||||
|
||||
if sorted(locations) != sorted(locations_to_sample):
|
||||
# if sorted(locations) != sorted(locations_to_sample):
|
||||
if domains == "all":
|
||||
domains = [condition for condition in self.conditions]
|
||||
else:
|
||||
check_consistency(domains, str)
|
||||
print(domains)
|
||||
if sorted(domains) != sorted(self.conditions):
|
||||
TypeError(
|
||||
f"Wrong locations for sampling. Location ",
|
||||
f"should be in {locations_to_sample}.",
|
||||
)
|
||||
|
||||
# sampling
|
||||
for location in locations:
|
||||
condition = self.conditions[location]
|
||||
for d in domains:
|
||||
condition = self.conditions[d]
|
||||
|
||||
# we try to check if we have already sampled
|
||||
try:
|
||||
already_sampled = [self.input_pts[location]]
|
||||
already_sampled = [self.input_pts[d]]
|
||||
# if we have not sampled, a key error is thrown
|
||||
except KeyError:
|
||||
already_sampled = []
|
||||
@@ -225,25 +231,27 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
# but we want to sample again we set already_sampled
|
||||
# to an empty list since we need to sample again, and
|
||||
# self._have_sampled_points to False.
|
||||
if self._have_sampled_points[location]:
|
||||
if self._discretized_domains[d]:
|
||||
already_sampled = []
|
||||
self._have_sampled_points[location] = False
|
||||
|
||||
self._discretized_domains[d] = False
|
||||
print(condition.domain)
|
||||
print(d)
|
||||
# build samples
|
||||
samples = [
|
||||
condition.location.sample(n=n, mode=mode, variables=variables)
|
||||
self.domains[d].sample(n=n, mode=mode, variables=variables)
|
||||
] + already_sampled
|
||||
pts = merge_tensors(samples)
|
||||
self.input_pts[location] = pts
|
||||
self.input_pts[d] = pts
|
||||
|
||||
# the condition is sampled if input_pts contains all labels
|
||||
if sorted(self.input_pts[location].labels) == sorted(
|
||||
if sorted(self.input_pts[d].labels) == sorted(
|
||||
self.input_variables
|
||||
):
|
||||
self._have_sampled_points[location] = True
|
||||
self.input_pts[location] = self.input_pts[location].extract(
|
||||
sorted(self.input_variables)
|
||||
)
|
||||
# self._have_sampled_points[location] = True
|
||||
# self.input_pts[location] = self.input_pts[location].extract(
|
||||
# sorted(self.input_variables)
|
||||
# )
|
||||
self._have_sampled_points[d] = True
|
||||
|
||||
def add_points(self, new_points):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user