Update of LabelTensor class and fix Simplex domain (#362)

*Implement new methods in LabelTensor and fix operators
This commit is contained in:
Filippo Olivo
2024-10-10 18:26:52 +02:00
committed by Nicola Demo
parent fdb8f65143
commit 7528f6ef74
19 changed files with 551 additions and 217 deletions

View File

@@ -1,3 +1,6 @@
from sympy.strategies.branch import condition
from . import LabelTensor
from .utils import check_consistency, merge_tensors
class Collector:
@@ -51,7 +54,7 @@ class Collector:
already_sampled = []
# if we have sampled the condition but not all variables
else:
already_sampled = [self.data_collections[loc].input_points]
already_sampled = [self.data_collections[loc]['input_points']]
# if the condition is ready but we want to sample again
else:
self.is_conditions_ready[loc] = False
@@ -63,10 +66,24 @@ class Collector:
] + already_sampled
pts = merge_tensors(samples)
if (
sorted(self.data_collections[loc].input_points.labels)
==
sorted(self.problem.input_variables)
set(pts.labels).issubset(sorted(self.problem.input_variables))
):
self.is_conditions_ready[loc] = True
pts = pts.sort_labels()
if sorted(pts.labels)==sorted(self.problem.input_variables):
self.is_conditions_ready[loc] = True
values = [pts, condition.equation]
self.data_collections[loc] = dict(zip(keys, values))
self.data_collections[loc] = dict(zip(keys, values))
else:
raise RuntimeError('Try to sample variables which are not in problem defined in the problem')
def add_points(self, new_points_dict):
"""
Add input points to a sampled condition
:param new_points_dict: Dictonary of input points (condition_name: LabelTensor)
:raises RuntimeError: if at least one condition is not already sampled
"""
for k,v in new_points_dict.items():
if not self.is_conditions_ready[k]:
raise RuntimeError('Cannot add points on a non sampled condition')
self.data_collections[k]['input_points'] = self.data_collections[k]['input_points'].vstack(v)