edited utils to take list (#115)

* enhanced difference domain
* refactored utils
* fixed typo
* added tests
---------

Co-authored-by: Dario Coscia <93731561+dario-coscia@users.noreply.github.com>
This commit is contained in:
Kush
2023-06-19 18:47:52 +02:00
committed by Nicola Demo
parent aaf2bed732
commit 62ec69ccac
9 changed files with 73 additions and 47 deletions

View File

@@ -3,16 +3,17 @@
from .location import Location
from ..label_tensor import LabelTensor
class Difference(Location):
"""
"""
def __init__(self, first, second):
def __init__(self, first, second):
self.first = first
self.second = second
def sample(self, n, mode ='random', variables='all'):
def sample(self, n, mode='random', variables='all'):
"""
"""
assert mode is 'random', 'Only random mode is implemented'
@@ -24,4 +25,4 @@ class Difference(Location):
samples.append(sample.tolist()[0])
import torch
return LabelTensor(torch.tensor(samples), labels=['x', 'y'])
return LabelTensor(torch.tensor(samples), labels=['x', 'y'])

View File

@@ -25,9 +25,9 @@ class Union(Location):
super().__init__()
# union checks
self._check_union_inheritance(geometries)
self._check_union_consistency(geometries)
check_consistency(geometries, Location)
self._check_union_dimensions(geometries)
# assign geometries
self._geometries = geometries
@@ -36,7 +36,7 @@ class Union(Location):
"""
The geometries."""
return self._geometries
@property
def variables(self):
"""
@@ -116,7 +116,7 @@ class Union(Location):
return LabelTensor(torch.cat(sampled_points), labels=[f'{i}' for i in self.variables])
def _check_union_consistency(self, geometries):
def _check_union_dimensions(self, geometries):
"""Check if the dimensions of the geometries are consistent.
:param geometries: Geometries to be checked.
@@ -126,12 +126,3 @@ class Union(Location):
if geometry.variables != geometries[0].variables:
raise NotImplementedError(
f'The geometries need to be the same dimensions. {geometry.variables} is not equal to {geometries[0].variables}')
def _check_union_inheritance(self, geometries):
"""Check if the geometries are inherited from 'pina.geometry.Location'.
param geometries: Geometries to be checked.
:type geometries: list[Location]
"""
for idx, geometry in enumerate(geometries):
check_consistency(geometry, Location, f'geometry[{idx}]')