diff --git a/pina/geometry/difference_domain.py b/pina/geometry/difference_domain.py index a433e40..7b99214 100644 --- a/pina/geometry/difference_domain.py +++ b/pina/geometry/difference_domain.py @@ -3,16 +3,17 @@ from .location import Location from ..label_tensor import LabelTensor + class Difference(Location): """ """ - def __init__(self, first, second): + def __init__(self, first, second): self.first = first self.second = second - def sample(self, n, mode ='random', variables='all'): + def sample(self, n, mode='random', variables='all'): """ """ assert mode is 'random', 'Only random mode is implemented' @@ -24,4 +25,4 @@ class Difference(Location): samples.append(sample.tolist()[0]) import torch - return LabelTensor(torch.tensor(samples), labels=['x', 'y']) \ No newline at end of file + return LabelTensor(torch.tensor(samples), labels=['x', 'y']) diff --git a/pina/geometry/union_domain.py b/pina/geometry/union_domain.py index 74be819..f0d92cd 100644 --- a/pina/geometry/union_domain.py +++ b/pina/geometry/union_domain.py @@ -25,9 +25,9 @@ class Union(Location): super().__init__() # union checks - self._check_union_inheritance(geometries) - self._check_union_consistency(geometries) - + check_consistency(geometries, Location) + self._check_union_dimensions(geometries) + # assign geometries self._geometries = geometries @@ -36,7 +36,7 @@ class Union(Location): """ The geometries.""" return self._geometries - + @property def variables(self): """ @@ -116,7 +116,7 @@ class Union(Location): return LabelTensor(torch.cat(sampled_points), labels=[f'{i}' for i in self.variables]) - def _check_union_consistency(self, geometries): + def _check_union_dimensions(self, geometries): """Check if the dimensions of the geometries are consistent. :param geometries: Geometries to be checked. @@ -126,12 +126,3 @@ class Union(Location): if geometry.variables != geometries[0].variables: raise NotImplementedError( f'The geometries need to be the same dimensions. {geometry.variables} is not equal to {geometries[0].variables}') - - def _check_union_inheritance(self, geometries): - """Check if the geometries are inherited from 'pina.geometry.Location'. - - param geometries: Geometries to be checked. - :type geometries: list[Location] - """ - for idx, geometry in enumerate(geometries): - check_consistency(geometry, Location, f'geometry[{idx}]') diff --git a/pina/loss.py b/pina/loss.py index 0073fb2..6ded5f7 100644 --- a/pina/loss.py +++ b/pina/loss.py @@ -108,9 +108,9 @@ class LpLoss(LossInterface): super().__init__(reduction=reduction) # check consistency - check_consistency(p, (str,int,float), 'degree p') + check_consistency(p, (str,int,float)) self.p = p - check_consistency(relative, bool, 'relative') + check_consistency(relative, bool) self.relative = relative def forward(self, input, target): diff --git a/pina/model/network.py b/pina/model/network.py index 752d1df..76d94f6 100644 --- a/pina/model/network.py +++ b/pina/model/network.py @@ -4,20 +4,20 @@ from ..utils import check_consistency class Network(torch.nn.Module): - + def __init__(self, model, extra_features=None): super().__init__() # check model consistency - check_consistency(model, nn.Module, 'torch model') + check_consistency(model, nn.Module) self._model = model - # check consistency and assign extra fatures + # check consistency and assign extra fatures if extra_features is None: self._extra_features = [] else: for feat in extra_features: - check_consistency(feat, nn.Module, 'extra features') + check_consistency(feat, nn.Module) self._extra_features = nn.Sequential(*extra_features) # check model works with inputs @@ -44,4 +44,4 @@ class Network(torch.nn.Module): @property def extra_features(self): - return self._extra_features \ No newline at end of file + return self._extra_features diff --git a/pina/pinn.py b/pina/pinn.py index 86df135..c212ce3 100644 --- a/pina/pinn.py +++ b/pina/pinn.py @@ -48,11 +48,11 @@ class PINN(SolverInterface): super().__init__(model=model, problem=problem, extra_features=extra_features) # check consistency - check_consistency(optimizer, torch.optim.Optimizer, 'optimizer', subclass=True) - check_consistency(optimizer_kwargs, dict, 'optimizer_kwargs') - check_consistency(scheduler, LRScheduler, 'scheduler', subclass=True) - check_consistency(scheduler_kwargs, dict, 'scheduler_kwargs') - check_consistency(loss, (LossInterface, _Loss), 'loss', subclass=False) + check_consistency(optimizer, torch.optim.Optimizer, subclass=True) + check_consistency(optimizer_kwargs, dict) + check_consistency(scheduler, LRScheduler, subclass=True) + check_consistency(scheduler_kwargs, dict) + check_consistency(loss, (LossInterface, _Loss), subclass=False) # assign variables self._optimizer = optimizer(self.model.parameters(), **optimizer_kwargs) diff --git a/pina/solver.py b/pina/solver.py index 9625603..18e6e99 100644 --- a/pina/solver.py +++ b/pina/solver.py @@ -20,7 +20,7 @@ class SolverInterface(pl.LightningModule, metaclass=ABCMeta): super().__init__() # check inheritance for pina problem - check_consistency(problem, AbstractProblem, 'pina problem') + check_consistency(problem, AbstractProblem) # assigning class variables (check consistency inside Network class) self._pina_model = Network(model=model, extra_features=extra_features) diff --git a/pina/trainer.py b/pina/trainer.py index 997f14e..4eacf50 100644 --- a/pina/trainer.py +++ b/pina/trainer.py @@ -11,7 +11,7 @@ class Trainer(pl.Trainer): super().__init__(**kwargs) # check inheritance consistency for solver - check_consistency(solver, SolverInterface, 'Solver model') + check_consistency(solver, SolverInterface) self._model = solver # create dataloader diff --git a/pina/utils.py b/pina/utils.py index 798a4ad..a6bf0c5 100644 --- a/pina/utils.py +++ b/pina/utils.py @@ -1,4 +1,5 @@ """Utils module""" +from torch.utils.data import Dataset, DataLoader from functools import reduce import types @@ -10,14 +11,14 @@ from .label_tensor import LabelTensor import torch -def check_consistency(object, object_instance, object_name, subclass=False): +def check_consistency(object, object_instance, subclass=False): """Helper function to check object inheritance consistency. Given a specific ``'object'`` we check if the object is instance of a specific ``'object_instance'``, or in case ``'subclass=True'`` we check if the object is subclass if the ``'object_instance'``. - :param Object object: The object to check the inheritance + :param (iterable or class object) object: The object to check the inheritance :param Object object_instance: The parent class from where the object is expected to inherit :param str object_name: The name of the object @@ -25,12 +26,17 @@ def check_consistency(object, object_instance, object_name, subclass=False): :raises ValueError: If the object does not inherit from the specified class """ - if not subclass: - if not isinstance(object, object_instance): - raise ValueError(f"{object_name} must be {object_instance}") - else: - if not issubclass(object, object_instance): - raise ValueError(f"{object_name} must be {object_instance}") + if not isinstance(object, (list, set, tuple)): + object = [object] + + for obj in object: + try: + if not subclass: + assert isinstance(obj, object_instance) + else: + assert issubclass(obj, object_instance) + except AssertionError: + raise ValueError(f"{type(obj).__name__} must be {object_instance}.") def number_parameters(model, aggregate=True, only_trainable=True): # TODO: check @@ -180,13 +186,13 @@ def chebyshev_roots(n): # def __len__(self): # return self._len -from torch.utils.data import Dataset, DataLoader + class LabelTensorDataset(Dataset): def __init__(self, d): for k, v in d.items(): setattr(self, k, v) self.labels = list(d.keys()) - + def __getitem__(self, index): print(index) result = {} @@ -201,7 +207,7 @@ class LabelTensorDataset(Dataset): result[label] = sample_tensor[index] except IndexError: result[label] = torch.tensor([]) - + print(result) return result @@ -229,13 +235,13 @@ class LabelTensorDataLoader(DataLoader): # def __len__(self): # return self._len -from torch.utils.data import Dataset, DataLoader + class LabelTensorDataset(Dataset): def __init__(self, d): for k, v in d.items(): setattr(self, k, v) self.labels = list(d.keys()) - + def __getitem__(self, index): print(index) result = {} @@ -250,7 +256,7 @@ class LabelTensorDataset(Dataset): result[label] = sample_tensor[index] except IndexError: result[label] = torch.tensor([]) - + print(result) return result @@ -261,4 +267,4 @@ class LabelTensorDataset(Dataset): class LabelTensorDataLoader(DataLoader): def collate_fn(self, data): - pass \ No newline at end of file + pass diff --git a/tests/test_utils.py b/tests/test_utils.py index 7ed795f..94895e5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,6 +2,12 @@ import torch from pina.utils import merge_tensors from pina.label_tensor import LabelTensor +from pina import LabelTensor +from pina.geometry import EllipsoidDomain, CartesianDomain +from pina.utils import check_consistency +import pytest +from pina.geometry import Location + def test_merge_tensors(): tensor1 = LabelTensor(torch.rand((20, 3)), ['a', 'b', 'c']) @@ -9,7 +15,29 @@ def test_merge_tensors(): tensor3 = LabelTensor(torch.ones((30, 3)), ['g', 'h', 'i']) merged_tensor = merge_tensors((tensor1, tensor2, tensor3)) - assert tuple(merged_tensor.labels) == ('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i') + assert tuple(merged_tensor.labels) == ( + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i') assert merged_tensor.shape == (20*20*30, 9) assert torch.all(merged_tensor.extract(('d', 'e', 'f')) == 0) assert torch.all(merged_tensor.extract(('g', 'h', 'i')) == 1) + + +def test_check_consistency_correct(): + ellipsoid1 = EllipsoidDomain({'x': [1, 2], 'y': [-2, 1]}) + example_input_pts = LabelTensor(torch.tensor([[0, 0, 0]]), ['x', 'y', 'z']) + + check_consistency(example_input_pts, torch.Tensor) + check_consistency(CartesianDomain, Location, subclass=True) + check_consistency(ellipsoid1, Location) + + +def test_check_consistency_incorrect(): + ellipsoid1 = EllipsoidDomain({'x': [1, 2], 'y': [-2, 1]}) + example_input_pts = LabelTensor(torch.tensor([[0, 0, 0]]), ['x', 'y', 'z']) + + with pytest.raises(ValueError): + check_consistency(example_input_pts, Location) + with pytest.raises(ValueError): + check_consistency(torch.Tensor, Location, subclass=True) + with pytest.raises(ValueError): + check_consistency(ellipsoid1, torch.Tensor)