Update of LabelTensor class and fix Simplex domain (#362)

*Implement new methods in LabelTensor and fix operators
This commit is contained in:
Filippo Olivo
2024-10-10 18:26:52 +02:00
committed by Nicola Demo
parent fdb8f65143
commit 7528f6ef74
19 changed files with 551 additions and 217 deletions

View File

@@ -77,7 +77,7 @@ class Difference(OperationInterface):
5
"""
if mode != self.sample_modes:
if mode not in self.sample_modes:
raise NotImplementedError(
f"{mode} is not a valid mode for sampling."
)

View File

@@ -76,7 +76,7 @@ class Exclusion(OperationInterface):
5
"""
if mode != self.sample_modes:
if mode not in self.sample_modes:
raise NotImplementedError(
f"{mode} is not a valid mode for sampling."
)

View File

@@ -78,7 +78,7 @@ class Intersection(OperationInterface):
5
"""
if mode != self.sample_modes:
if mode not in self.sample_modes:
raise NotImplementedError(
f"{mode} is not a valid mode for sampling."
)

View File

@@ -92,13 +92,12 @@ class SimplexDomain(DomainInterface):
"""
span_dict = {}
for i, coord in enumerate(self.variables):
sorted_vertices = sorted(vertices, key=lambda vertex: vertex[i])
sorted_vertices = torch.sort(vertices[coord].tensor.squeeze())
# respective coord bounded by the lowest and highest values
span_dict[coord] = [
float(sorted_vertices[0][i]),
float(sorted_vertices[-1][i]),
float(sorted_vertices.values[0]),
float(sorted_vertices.values[-1]),
]
return CartesianDomain(span_dict)

View File

@@ -41,7 +41,10 @@ class Union(OperationInterface):
@property
def variables(self):
return list(set([geom.variables for geom in self.geometries]))
variables = []
for geom in self.geometries:
variables+=geom.variables
return list(set(variables))
def is_inside(self, point, check_border=False):
"""