Update of LabelTensor class and fix Simplex domain (#362)
*Implement new methods in LabelTensor and fix operators
This commit is contained in:
committed by
Nicola Demo
parent
fdb8f65143
commit
7528f6ef74
@@ -1,3 +1,6 @@
|
||||
from sympy.strategies.branch import condition
|
||||
|
||||
from . import LabelTensor
|
||||
from .utils import check_consistency, merge_tensors
|
||||
|
||||
class Collector:
|
||||
@@ -51,7 +54,7 @@ class Collector:
|
||||
already_sampled = []
|
||||
# if we have sampled the condition but not all variables
|
||||
else:
|
||||
already_sampled = [self.data_collections[loc].input_points]
|
||||
already_sampled = [self.data_collections[loc]['input_points']]
|
||||
# if the condition is ready but we want to sample again
|
||||
else:
|
||||
self.is_conditions_ready[loc] = False
|
||||
@@ -63,10 +66,24 @@ class Collector:
|
||||
] + already_sampled
|
||||
pts = merge_tensors(samples)
|
||||
if (
|
||||
sorted(self.data_collections[loc].input_points.labels)
|
||||
==
|
||||
sorted(self.problem.input_variables)
|
||||
set(pts.labels).issubset(sorted(self.problem.input_variables))
|
||||
):
|
||||
self.is_conditions_ready[loc] = True
|
||||
pts = pts.sort_labels()
|
||||
if sorted(pts.labels)==sorted(self.problem.input_variables):
|
||||
self.is_conditions_ready[loc] = True
|
||||
values = [pts, condition.equation]
|
||||
self.data_collections[loc] = dict(zip(keys, values))
|
||||
self.data_collections[loc] = dict(zip(keys, values))
|
||||
else:
|
||||
raise RuntimeError('Try to sample variables which are not in problem defined in the problem')
|
||||
|
||||
def add_points(self, new_points_dict):
|
||||
"""
|
||||
Add input points to a sampled condition
|
||||
|
||||
:param new_points_dict: Dictonary of input points (condition_name: LabelTensor)
|
||||
:raises RuntimeError: if at least one condition is not already sampled
|
||||
"""
|
||||
for k,v in new_points_dict.items():
|
||||
if not self.is_conditions_ready[k]:
|
||||
raise RuntimeError('Cannot add points on a non sampled condition')
|
||||
self.data_collections[k]['input_points'] = self.data_collections[k]['input_points'].vstack(v)
|
||||
Reference in New Issue
Block a user