edited utils to take list (#115)
* enhanced difference domain * refactored utils * fixed typo * added tests --------- Co-authored-by: Dario Coscia <93731561+dario-coscia@users.noreply.github.com>
This commit is contained in:
@@ -3,11 +3,12 @@
|
|||||||
from .location import Location
|
from .location import Location
|
||||||
from ..label_tensor import LabelTensor
|
from ..label_tensor import LabelTensor
|
||||||
|
|
||||||
|
|
||||||
class Difference(Location):
|
class Difference(Location):
|
||||||
"""
|
"""
|
||||||
"""
|
"""
|
||||||
def __init__(self, first, second):
|
|
||||||
|
|
||||||
|
def __init__(self, first, second):
|
||||||
|
|
||||||
self.first = first
|
self.first = first
|
||||||
self.second = second
|
self.second = second
|
||||||
|
|||||||
@@ -25,8 +25,8 @@ class Union(Location):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# union checks
|
# union checks
|
||||||
self._check_union_inheritance(geometries)
|
check_consistency(geometries, Location)
|
||||||
self._check_union_consistency(geometries)
|
self._check_union_dimensions(geometries)
|
||||||
|
|
||||||
# assign geometries
|
# assign geometries
|
||||||
self._geometries = geometries
|
self._geometries = geometries
|
||||||
@@ -116,7 +116,7 @@ class Union(Location):
|
|||||||
|
|
||||||
return LabelTensor(torch.cat(sampled_points), labels=[f'{i}' for i in self.variables])
|
return LabelTensor(torch.cat(sampled_points), labels=[f'{i}' for i in self.variables])
|
||||||
|
|
||||||
def _check_union_consistency(self, geometries):
|
def _check_union_dimensions(self, geometries):
|
||||||
"""Check if the dimensions of the geometries are consistent.
|
"""Check if the dimensions of the geometries are consistent.
|
||||||
|
|
||||||
:param geometries: Geometries to be checked.
|
:param geometries: Geometries to be checked.
|
||||||
@@ -126,12 +126,3 @@ class Union(Location):
|
|||||||
if geometry.variables != geometries[0].variables:
|
if geometry.variables != geometries[0].variables:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f'The geometries need to be the same dimensions. {geometry.variables} is not equal to {geometries[0].variables}')
|
f'The geometries need to be the same dimensions. {geometry.variables} is not equal to {geometries[0].variables}')
|
||||||
|
|
||||||
def _check_union_inheritance(self, geometries):
|
|
||||||
"""Check if the geometries are inherited from 'pina.geometry.Location'.
|
|
||||||
|
|
||||||
param geometries: Geometries to be checked.
|
|
||||||
:type geometries: list[Location]
|
|
||||||
"""
|
|
||||||
for idx, geometry in enumerate(geometries):
|
|
||||||
check_consistency(geometry, Location, f'geometry[{idx}]')
|
|
||||||
|
|||||||
@@ -108,9 +108,9 @@ class LpLoss(LossInterface):
|
|||||||
super().__init__(reduction=reduction)
|
super().__init__(reduction=reduction)
|
||||||
|
|
||||||
# check consistency
|
# check consistency
|
||||||
check_consistency(p, (str,int,float), 'degree p')
|
check_consistency(p, (str,int,float))
|
||||||
self.p = p
|
self.p = p
|
||||||
check_consistency(relative, bool, 'relative')
|
check_consistency(relative, bool)
|
||||||
self.relative = relative
|
self.relative = relative
|
||||||
|
|
||||||
def forward(self, input, target):
|
def forward(self, input, target):
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ class Network(torch.nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# check model consistency
|
# check model consistency
|
||||||
check_consistency(model, nn.Module, 'torch model')
|
check_consistency(model, nn.Module)
|
||||||
self._model = model
|
self._model = model
|
||||||
|
|
||||||
# check consistency and assign extra fatures
|
# check consistency and assign extra fatures
|
||||||
@@ -17,7 +17,7 @@ class Network(torch.nn.Module):
|
|||||||
self._extra_features = []
|
self._extra_features = []
|
||||||
else:
|
else:
|
||||||
for feat in extra_features:
|
for feat in extra_features:
|
||||||
check_consistency(feat, nn.Module, 'extra features')
|
check_consistency(feat, nn.Module)
|
||||||
self._extra_features = nn.Sequential(*extra_features)
|
self._extra_features = nn.Sequential(*extra_features)
|
||||||
|
|
||||||
# check model works with inputs
|
# check model works with inputs
|
||||||
|
|||||||
10
pina/pinn.py
10
pina/pinn.py
@@ -48,11 +48,11 @@ class PINN(SolverInterface):
|
|||||||
super().__init__(model=model, problem=problem, extra_features=extra_features)
|
super().__init__(model=model, problem=problem, extra_features=extra_features)
|
||||||
|
|
||||||
# check consistency
|
# check consistency
|
||||||
check_consistency(optimizer, torch.optim.Optimizer, 'optimizer', subclass=True)
|
check_consistency(optimizer, torch.optim.Optimizer, subclass=True)
|
||||||
check_consistency(optimizer_kwargs, dict, 'optimizer_kwargs')
|
check_consistency(optimizer_kwargs, dict)
|
||||||
check_consistency(scheduler, LRScheduler, 'scheduler', subclass=True)
|
check_consistency(scheduler, LRScheduler, subclass=True)
|
||||||
check_consistency(scheduler_kwargs, dict, 'scheduler_kwargs')
|
check_consistency(scheduler_kwargs, dict)
|
||||||
check_consistency(loss, (LossInterface, _Loss), 'loss', subclass=False)
|
check_consistency(loss, (LossInterface, _Loss), subclass=False)
|
||||||
|
|
||||||
# assign variables
|
# assign variables
|
||||||
self._optimizer = optimizer(self.model.parameters(), **optimizer_kwargs)
|
self._optimizer = optimizer(self.model.parameters(), **optimizer_kwargs)
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ class SolverInterface(pl.LightningModule, metaclass=ABCMeta):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# check inheritance for pina problem
|
# check inheritance for pina problem
|
||||||
check_consistency(problem, AbstractProblem, 'pina problem')
|
check_consistency(problem, AbstractProblem)
|
||||||
|
|
||||||
# assigning class variables (check consistency inside Network class)
|
# assigning class variables (check consistency inside Network class)
|
||||||
self._pina_model = Network(model=model, extra_features=extra_features)
|
self._pina_model = Network(model=model, extra_features=extra_features)
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ class Trainer(pl.Trainer):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
# check inheritance consistency for solver
|
# check inheritance consistency for solver
|
||||||
check_consistency(solver, SolverInterface, 'Solver model')
|
check_consistency(solver, SolverInterface)
|
||||||
self._model = solver
|
self._model = solver
|
||||||
|
|
||||||
# create dataloader
|
# create dataloader
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Utils module"""
|
"""Utils module"""
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
import types
|
import types
|
||||||
|
|
||||||
@@ -10,14 +11,14 @@ from .label_tensor import LabelTensor
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def check_consistency(object, object_instance, object_name, subclass=False):
|
def check_consistency(object, object_instance, subclass=False):
|
||||||
"""Helper function to check object inheritance consistency.
|
"""Helper function to check object inheritance consistency.
|
||||||
Given a specific ``'object'`` we check if the object is
|
Given a specific ``'object'`` we check if the object is
|
||||||
instance of a specific ``'object_instance'``, or in case
|
instance of a specific ``'object_instance'``, or in case
|
||||||
``'subclass=True'`` we check if the object is subclass
|
``'subclass=True'`` we check if the object is subclass
|
||||||
if the ``'object_instance'``.
|
if the ``'object_instance'``.
|
||||||
|
|
||||||
:param Object object: The object to check the inheritance
|
:param (iterable or class object) object: The object to check the inheritance
|
||||||
:param Object object_instance: The parent class from where the object
|
:param Object object_instance: The parent class from where the object
|
||||||
is expected to inherit
|
is expected to inherit
|
||||||
:param str object_name: The name of the object
|
:param str object_name: The name of the object
|
||||||
@@ -25,12 +26,17 @@ def check_consistency(object, object_instance, object_name, subclass=False):
|
|||||||
:raises ValueError: If the object does not inherit from the
|
:raises ValueError: If the object does not inherit from the
|
||||||
specified class
|
specified class
|
||||||
"""
|
"""
|
||||||
|
if not isinstance(object, (list, set, tuple)):
|
||||||
|
object = [object]
|
||||||
|
|
||||||
|
for obj in object:
|
||||||
|
try:
|
||||||
if not subclass:
|
if not subclass:
|
||||||
if not isinstance(object, object_instance):
|
assert isinstance(obj, object_instance)
|
||||||
raise ValueError(f"{object_name} must be {object_instance}")
|
|
||||||
else:
|
else:
|
||||||
if not issubclass(object, object_instance):
|
assert issubclass(obj, object_instance)
|
||||||
raise ValueError(f"{object_name} must be {object_instance}")
|
except AssertionError:
|
||||||
|
raise ValueError(f"{type(obj).__name__} must be {object_instance}.")
|
||||||
|
|
||||||
|
|
||||||
def number_parameters(model, aggregate=True, only_trainable=True): # TODO: check
|
def number_parameters(model, aggregate=True, only_trainable=True): # TODO: check
|
||||||
@@ -180,7 +186,7 @@ def chebyshev_roots(n):
|
|||||||
# def __len__(self):
|
# def __len__(self):
|
||||||
# return self._len
|
# return self._len
|
||||||
|
|
||||||
from torch.utils.data import Dataset, DataLoader
|
|
||||||
class LabelTensorDataset(Dataset):
|
class LabelTensorDataset(Dataset):
|
||||||
def __init__(self, d):
|
def __init__(self, d):
|
||||||
for k, v in d.items():
|
for k, v in d.items():
|
||||||
@@ -229,7 +235,7 @@ class LabelTensorDataLoader(DataLoader):
|
|||||||
# def __len__(self):
|
# def __len__(self):
|
||||||
# return self._len
|
# return self._len
|
||||||
|
|
||||||
from torch.utils.data import Dataset, DataLoader
|
|
||||||
class LabelTensorDataset(Dataset):
|
class LabelTensorDataset(Dataset):
|
||||||
def __init__(self, d):
|
def __init__(self, d):
|
||||||
for k, v in d.items():
|
for k, v in d.items():
|
||||||
|
|||||||
@@ -2,6 +2,12 @@ import torch
|
|||||||
|
|
||||||
from pina.utils import merge_tensors
|
from pina.utils import merge_tensors
|
||||||
from pina.label_tensor import LabelTensor
|
from pina.label_tensor import LabelTensor
|
||||||
|
from pina import LabelTensor
|
||||||
|
from pina.geometry import EllipsoidDomain, CartesianDomain
|
||||||
|
from pina.utils import check_consistency
|
||||||
|
import pytest
|
||||||
|
from pina.geometry import Location
|
||||||
|
|
||||||
|
|
||||||
def test_merge_tensors():
|
def test_merge_tensors():
|
||||||
tensor1 = LabelTensor(torch.rand((20, 3)), ['a', 'b', 'c'])
|
tensor1 = LabelTensor(torch.rand((20, 3)), ['a', 'b', 'c'])
|
||||||
@@ -9,7 +15,29 @@ def test_merge_tensors():
|
|||||||
tensor3 = LabelTensor(torch.ones((30, 3)), ['g', 'h', 'i'])
|
tensor3 = LabelTensor(torch.ones((30, 3)), ['g', 'h', 'i'])
|
||||||
|
|
||||||
merged_tensor = merge_tensors((tensor1, tensor2, tensor3))
|
merged_tensor = merge_tensors((tensor1, tensor2, tensor3))
|
||||||
assert tuple(merged_tensor.labels) == ('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i')
|
assert tuple(merged_tensor.labels) == (
|
||||||
|
'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i')
|
||||||
assert merged_tensor.shape == (20*20*30, 9)
|
assert merged_tensor.shape == (20*20*30, 9)
|
||||||
assert torch.all(merged_tensor.extract(('d', 'e', 'f')) == 0)
|
assert torch.all(merged_tensor.extract(('d', 'e', 'f')) == 0)
|
||||||
assert torch.all(merged_tensor.extract(('g', 'h', 'i')) == 1)
|
assert torch.all(merged_tensor.extract(('g', 'h', 'i')) == 1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_consistency_correct():
|
||||||
|
ellipsoid1 = EllipsoidDomain({'x': [1, 2], 'y': [-2, 1]})
|
||||||
|
example_input_pts = LabelTensor(torch.tensor([[0, 0, 0]]), ['x', 'y', 'z'])
|
||||||
|
|
||||||
|
check_consistency(example_input_pts, torch.Tensor)
|
||||||
|
check_consistency(CartesianDomain, Location, subclass=True)
|
||||||
|
check_consistency(ellipsoid1, Location)
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_consistency_incorrect():
|
||||||
|
ellipsoid1 = EllipsoidDomain({'x': [1, 2], 'y': [-2, 1]})
|
||||||
|
example_input_pts = LabelTensor(torch.tensor([[0, 0, 0]]), ['x', 'y', 'z'])
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
check_consistency(example_input_pts, Location)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
check_consistency(torch.Tensor, Location, subclass=True)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
check_consistency(ellipsoid1, torch.Tensor)
|
||||||
|
|||||||
Reference in New Issue
Block a user