Implement Dataset, Dataloader and DataModule class and fix SupervisedSolver

This commit is contained in:
FilippoOlivo
2024-10-16 11:24:37 +02:00
committed by Nicola Demo
parent b9753c34b2
commit c9304fb9bb
30 changed files with 770 additions and 784 deletions

View File

@@ -33,7 +33,7 @@ class CartesianDomain(DomainInterface):
@property
def sample_modes(self):
return ["random", "grid", "lh", "chebyshev", "latin"]
@property
def variables(self):
"""Spatial variables.

View File

@@ -55,7 +55,6 @@ class EllipsoidDomain(DomainInterface):
# perform operation only for not fixed variables (if any)
if self.range_:
# convert dict vals to torch [dim, 2] matrix
list_dict_vals = list(self.range_.values())
tmp = torch.tensor(list_dict_vals, dtype=torch.float)
@@ -74,7 +73,7 @@ class EllipsoidDomain(DomainInterface):
@property
def sample_modes(self):
return ["random"]
@property
def variables(self):
"""Spatial variables.

View File

@@ -69,4 +69,4 @@ class OperationInterface(DomainInterface, metaclass=ABCMeta):
if geometry.variables != geometries[0].variables:
raise NotImplementedError(
f"The geometries need to have same dimensions and labels."
)
)

View File

@@ -77,7 +77,7 @@ class SimplexDomain(DomainInterface):
@property
def sample_modes(self):
return ["random"]
@property
def variables(self):
return sorted(self._vertices_matrix.labels)
@@ -144,7 +144,7 @@ class SimplexDomain(DomainInterface):
return all(torch.gt(lambdas, 0.0)) and all(torch.lt(lambdas, 1.0))
return all(torch.ge(lambdas, 0)) and (
any(torch.eq(lambdas, 0)) or any(torch.eq(lambdas, 1))
any(torch.eq(lambdas, 0)) or any(torch.eq(lambdas, 1))
)
def _sample_interior_randomly(self, n, variables):

View File

@@ -37,13 +37,13 @@ class Union(OperationInterface):
def sample_modes(self):
self.sample_modes = list(
set([geom.sample_modes for geom in self.geometries])
)
)
@property
def variables(self):
variables = []
for geom in self.geometries:
variables+=geom.variables
variables += geom.variables
return list(set(variables))
def is_inside(self, point, check_border=False):