Codacy correction

This commit is contained in:
FilippoOlivo
2024-10-31 09:50:19 +01:00
committed by Nicola Demo
parent ea3d1924e7
commit dd43c8304c
23 changed files with 246 additions and 214 deletions

View File

@@ -1,7 +1,6 @@
__all__ = [
"Trainer", "LabelTensor", "Plotter", "Condition",
"SamplePointDataset", "PinaDataModule", "PinaDataLoader",
'TorchOptimizer', 'Graph'
"Trainer", "LabelTensor", "Plotter", "Condition", "SamplePointDataset",
"PinaDataModule", "PinaDataLoader", 'TorchOptimizer', 'Graph'
]
from .meta import *
@@ -15,4 +14,4 @@ from .data import PinaDataModule
from .data import PinaDataLoader
from .optim import TorchOptimizer
from .optim import TorchScheduler
from .graph import Graph
from .graph import Graph

View File

@@ -2,23 +2,28 @@ from .utils import check_consistency, merge_tensors
class Collector:
def __init__(self, problem):
# creating a hook between collector and problem
self.problem = problem
# this variable is used to store the data in the form:
# {'[condition_name]' :
# {'input_points' : Tensor,
# {'[condition_name]' :
# {'input_points' : Tensor,
# '[equation/output_points/conditional_variables]': Tensor}
# }
# those variables are used for the dataloading
self._data_collections = {name: {} for name in self.problem.conditions}
self.conditions_name = {i: name for i, name in
enumerate(self.problem.conditions)}
self.conditions_name = {
i: name
for i, name in enumerate(self.problem.conditions)
}
# variables used to check that all conditions are sampled
self._is_conditions_ready = {
name: False for name in self.problem.conditions}
name: False
for name in self.problem.conditions
}
self.full = False
@property
@@ -47,8 +52,8 @@ class Collector:
for condition_name, condition in self.problem.conditions.items():
# if the condition is not ready and domain is not attribute
# of condition, we get and store the data
if (not self._is_conditions_ready[condition_name]) and (
not hasattr(condition, "domain")):
if (not self._is_conditions_ready[condition_name]) and (not hasattr(
condition, "domain")):
# get data
keys = condition.__slots__
values = [getattr(condition, name) for name in keys]
@@ -70,7 +75,8 @@ class Collector:
# if we have sampled the condition but not all variables
else:
already_sampled = [
self.data_collections[loc]['input_points']]
self.data_collections[loc]['input_points']
]
# if the condition is ready but we want to sample again
else:
self._is_conditions_ready[loc] = False
@@ -78,14 +84,10 @@ class Collector:
# get the samples
samples = [
condition.domain.sample(n=n, mode=mode,
variables=variables)
] + already_sampled
condition.domain.sample(n=n, mode=mode, variables=variables)
] + already_sampled
pts = merge_tensors(samples)
if (
set(pts.labels).issubset(
sorted(self.problem.input_variables))
):
if (set(pts.labels).issubset(sorted(self.problem.input_variables))):
pts = pts.sort_labels()
if sorted(pts.labels) == sorted(self.problem.input_variables):
self._is_conditions_ready[loc] = True

View File

@@ -39,21 +39,16 @@ class Condition:
"""
__slots__ = list(
set(
InputOutputPointsCondition.__slots__ +
set(InputOutputPointsCondition.__slots__ +
InputPointsEquationCondition.__slots__ +
DomainEquationCondition.__slots__ +
DataConditionInterface.__slots__
)
)
DataConditionInterface.__slots__))
def __new__(cls, *args, **kwargs):
if len(args) != 0:
raise ValueError(
"Condition takes only the following keyword "
f"arguments: {Condition.__slots__}."
)
raise ValueError("Condition takes only the following keyword "
f"arguments: {Condition.__slots__}.")
sorted_keys = sorted(kwargs.keys())
if sorted_keys == sorted(InputOutputPointsCondition.__slots__):

View File

@@ -1,6 +1,6 @@
from abc import ABCMeta
class ConditionInterface(metaclass=ABCMeta):
condition_types = ['physics', 'supervised', 'unsupervised']
@@ -12,7 +12,7 @@ class ConditionInterface(metaclass=ABCMeta):
@property
def problem(self):
return self._problem
@problem.setter
def problem(self, value):
self._problem = value
@@ -20,15 +20,14 @@ class ConditionInterface(metaclass=ABCMeta):
@property
def condition_type(self):
return self._condition_type
@condition_type.setter
def condition_type(self, values):
if not isinstance(values, (list, tuple)):
values = [values]
for value in values:
if value not in ConditionInterface.condition_types:
raise ValueError(
'Unavailable type of condition, expected one of'
f' {ConditionInterface.condition_types}.'
)
self._condition_type = values
raise ValueError(
'Unavailable type of condition, expected one of'
f' {ConditionInterface.condition_types}.')
self._condition_type = values

View File

@@ -5,6 +5,7 @@ from ..label_tensor import LabelTensor
from ..graph import Graph
from ..utils import check_consistency
class DataConditionInterface(ConditionInterface):
"""
Condition for data. This condition must be used every
@@ -29,4 +30,4 @@ class DataConditionInterface(ConditionInterface):
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
DataConditionInterface.__dict__[key].__set__(self, value)
elif key in ('_problem', '_condition_type'):
super().__setattr__(key, value)
super().__setattr__(key, value)

View File

@@ -5,6 +5,7 @@ from ..utils import check_consistency
from ..domain import DomainInterface
from ..equation.equation_interface import EquationInterface
class DomainEquationCondition(ConditionInterface):
"""
Condition for domain/equation data. This condition must be used every
@@ -30,4 +31,4 @@ class DomainEquationCondition(ConditionInterface):
check_consistency(value, (EquationInterface))
DomainEquationCondition.__dict__[key].__set__(self, value)
elif key in ('_problem', '_condition_type'):
super().__setattr__(key, value)
super().__setattr__(key, value)

View File

@@ -6,6 +6,7 @@ from ..graph import Graph
from ..utils import check_consistency
from ..equation.equation_interface import EquationInterface
class InputPointsEquationCondition(ConditionInterface):
"""
Condition for input_points/equation data. This condition must be used every
@@ -25,10 +26,12 @@ class InputPointsEquationCondition(ConditionInterface):
def __setattr__(self, key, value):
if key == 'input_points':
check_consistency(value, (LabelTensor)) # for now only labeltensors, we need labels for the operators!
check_consistency(
value, (LabelTensor)
) # for now only labeltensors, we need labels for the operators!
InputPointsEquationCondition.__dict__[key].__set__(self, value)
elif key == 'equation':
check_consistency(value, (EquationInterface))
InputPointsEquationCondition.__dict__[key].__set__(self, value)
elif key in ('_problem', '_condition_type'):
super().__setattr__(key, value)
super().__setattr__(key, value)

View File

@@ -1,4 +1,3 @@
import torch
from .condition_interface import ConditionInterface
@@ -6,6 +5,7 @@ from ..label_tensor import LabelTensor
from ..graph import Graph
from ..utils import check_consistency
class InputOutputPointsCondition(ConditionInterface):
"""
Condition for domain/equation data. This condition must be used every

View File

@@ -59,9 +59,11 @@ class BaseDataset(Dataset):
keys = list(data.keys())
if set(self.__slots__) == set(keys):
self._populate_init_list(data)
idx = [key for key, val in
self.problem.collector.conditions_name.items() if
val == name]
idx = [
key for key, val in
self.problem.collector.conditions_name.items()
if val == name
]
self.conditions_idx.append(idx)
self.initialize()
@@ -89,15 +91,16 @@ class BaseDataset(Dataset):
if isinstance(slot_data, (LabelTensor, torch.Tensor)):
dims = len(slot_data.size())
slot_data = slot_data.permute(
[batching_dim] + [dim for dim in range(dims) if
dim != batching_dim])
[batching_dim] +
[dim for dim in range(dims) if dim != batching_dim])
if current_cond_num_el is None:
current_cond_num_el = len(slot_data)
elif current_cond_num_el != len(slot_data):
raise ValueError('Different dimension in same condition')
current_list = getattr(self, slot)
current_list += [slot_data] if not (
isinstance(slot_data, list)) else slot_data
current_list += [
slot_data
] if not (isinstance(slot_data, list)) else slot_data
self.num_el_per_condition.append(current_cond_num_el)
def initialize(self):
@@ -108,14 +111,12 @@ class BaseDataset(Dataset):
logging.debug(f'Initialize dataset {self.__class__.__name__}')
if self.num_el_per_condition:
self.condition_indices = torch.cat(
[
torch.tensor([i] * self.num_el_per_condition[i],
dtype=torch.uint8)
for i in range(len(self.num_el_per_condition))
],
dim=0
)
self.condition_indices = torch.cat([
torch.tensor([i] * self.num_el_per_condition[i],
dtype=torch.uint8)
for i in range(len(self.num_el_per_condition))
],
dim=0)
for slot in self.__slots__:
current_attribute = getattr(self, slot)
if all(isinstance(a, LabelTensor) for a in current_attribute):

View File

@@ -44,8 +44,9 @@ class PinaDataModule(LightningDataModule):
super().__init__()
self.problem = problem
self.device = device
self.dataset_classes = [SupervisedDataset, UnsupervisedDataset,
SamplePointDataset]
self.dataset_classes = [
SupervisedDataset, UnsupervisedDataset, SamplePointDataset
]
if datasets is None:
self.datasets = None
else:
@@ -71,15 +72,12 @@ class PinaDataModule(LightningDataModule):
self.split_length.append(val_size)
self.split_names.append('val')
self.loader_functions['val_dataloader'] = lambda: PinaDataLoader(
self.splits['val'], self.batch_size,
self.condition_names)
self.splits['val'], self.batch_size, self.condition_names)
if predict_size > 0:
self.split_length.append(predict_size)
self.split_names.append('predict')
self.loader_functions[
'predict_dataloader'] = lambda: PinaDataLoader(
self.splits['predict'], self.batch_size,
self.condition_names)
self.loader_functions['predict_dataloader'] = lambda: PinaDataLoader(
self.splits['predict'], self.batch_size, self.condition_names)
self.splits = {k: {} for k in self.split_names}
self.shuffle = shuffle
@@ -104,8 +102,8 @@ class PinaDataModule(LightningDataModule):
self.split_length,
shuffle=self.shuffle)
for i in range(len(self.split_length)):
self.splits[
self.split_names[i]][dataset.data_type] = splits[i]
self.splits[self.split_names[i]][
dataset.data_type] = splits[i]
elif stage == 'test':
raise NotImplementedError("Testing pipeline not implemented yet")
else:
@@ -137,14 +135,12 @@ class PinaDataModule(LightningDataModule):
if seed is not None:
generator = torch.Generator()
generator.manual_seed(seed)
indices = torch.randperm(sum(lengths),
generator=generator)
indices = torch.randperm(sum(lengths), generator=generator)
else:
indices = torch.randperm(sum(lengths))
dataset.apply_shuffle(indices)
indices = torch.arange(0, sum(lengths), 1,
dtype=torch.uint8).tolist()
indices = torch.arange(0, sum(lengths), 1, dtype=torch.uint8).tolist()
offsets = [
sum(lengths[:i]) if i > 0 else 0 for i in range(len(lengths))
]
@@ -161,13 +157,16 @@ class PinaDataModule(LightningDataModule):
collector = self.problem.collector
batching_dim = self.problem.batching_dimension
datasets_slots = [i.__slots__ for i in self.dataset_classes]
self.datasets = [dataset(device=self.device) for dataset in
self.dataset_classes]
self.datasets = [
dataset(device=self.device) for dataset in self.dataset_classes
]
logging.debug('Filling datasets in PinaDataModule obj')
for name, data in collector.data_collections.items():
keys = list(data.keys())
idx = [key for key, val in collector.conditions_name.items() if
val == name]
idx = [
key for key, val in collector.conditions_name.items()
if val == name
]
for i, slot in enumerate(datasets_slots):
if slot == keys:
self.datasets[i].add_points(data, idx[0], batching_dim)

View File

@@ -37,14 +37,11 @@ class Batch:
if item in super().__getattribute__('attributes'):
dataset = super().__getattribute__(item)
index = super().__getattribute__(item + '_idx')
return PinaSubset(
dataset.dataset,
dataset.indices[index])
else:
return super().__getattribute__(item)
return PinaSubset(dataset.dataset, dataset.indices[index])
return super().__getattribute__(item)
def __getattr__(self, item):
if item == 'data' and len(self.attributes) == 1:
item = self.attributes[0]
return super().__getattribute__(item)
raise AttributeError(f"'Batch' object has no attribute '{item}'")
raise AttributeError(f"'Batch' object has no attribute '{item}'")

View File

@@ -19,15 +19,17 @@ class SamplePointDataset(BaseDataset):
data_dict.pop('equation')
super().add_points(data_dict, condition_idx)
def _init_from_problem(self, collector_dict, batching_dim=0):
def _init_from_problem(self, collector_dict):
for name, data in collector_dict.items():
keys = list(data.keys())
if set(self.__slots__) == set(keys):
data = deepcopy(data)
data.pop('equation')
self._populate_init_list(data)
idx = [key for key, val in
self.problem.collector.conditions_name.items() if
val == name]
idx = [
key for key, val in
self.problem.collector.conditions_name.items()
if val == name
]
self.conditions_idx.append(idx)
self.initialize()

View File

@@ -168,9 +168,8 @@ class CartesianDomain(DomainInterface):
for variable in variables:
if variable in self.fixed_.keys():
value = self.fixed_[variable]
pts_variable = torch.tensor([[value]]).repeat(
result.shape[0], 1
)
pts_variable = torch.tensor([[value]
]).repeat(result.shape[0], 1)
pts_variable = pts_variable.as_subclass(LabelTensor)
pts_variable.labels = [variable]
@@ -203,9 +202,8 @@ class CartesianDomain(DomainInterface):
for variable in variables:
if variable in self.fixed_.keys():
value = self.fixed_[variable]
pts_variable = torch.tensor([[value]]).repeat(
result.shape[0], 1
)
pts_variable = torch.tensor([[value]
]).repeat(result.shape[0], 1)
pts_variable = pts_variable.as_subclass(LabelTensor)
pts_variable.labels = [variable]

View File

@@ -38,8 +38,7 @@ class DomainInterface(metaclass=ABCMeta):
if value not in DomainInterface.available_sampling_modes:
raise TypeError(f"mode {value} not valid. Expected at least "
"one in "
f"{DomainInterface.available_sampling_modes}."
)
f"{DomainInterface.available_sampling_modes}.")
@abstractmethod
def sample(self):

View File

@@ -1,5 +1,5 @@
""" Module for LabelTensor """
from copy import copy
from copy import copy, deepcopy
import torch
from torch import Tensor
@@ -10,17 +10,16 @@ def issubset(a, b):
"""
if isinstance(a, list) and isinstance(b, list):
return set(a).issubset(set(b))
elif isinstance(a, range) and isinstance(b, range):
if isinstance(a, range) and isinstance(b, range):
return a.start <= b.start and a.stop >= b.stop
else:
return False
return False
class LabelTensor(torch.Tensor):
"""Torch tensor with a label for any column."""
@staticmethod
def __new__(cls, x, labels, full=True, *args, **kwargs):
def __new__(cls, x, labels, *args, **kwargs):
if isinstance(x, LabelTensor):
return x
else:
@@ -30,7 +29,7 @@ class LabelTensor(torch.Tensor):
def tensor(self):
return self.as_subclass(Tensor)
def __init__(self, x, labels, full=False):
def __init__(self, x, labels, **kwargs):
"""
Construct a `LabelTensor` by passing a dict of the labels
@@ -42,14 +41,19 @@ class LabelTensor(torch.Tensor):
"""
self.dim_names = None
self.full = full
self.full = kwargs.get('full', True)
self.labels = labels
@classmethod
def __internal_init__(cls, x, labels, dim_names ,full=False, *args, **kwargs):
lt = cls.__new__(cls, x, labels, full, *args, **kwargs)
def __internal_init__(cls,
x,
labels,
dim_names,
*args,
**kwargs):
lt = cls.__new__(cls, x, labels, *args, **kwargs)
lt._labels = labels
lt.full = full
lt.full = kwargs.get('full', True)
lt.dim_names = dim_names
return lt
@@ -122,8 +126,12 @@ class LabelTensor(torch.Tensor):
tensor_shape = self.shape
if hasattr(self, 'full') and self.full:
labels = {i: labels[i] if i in labels else {'name': i} for i in
labels.keys()}
labels = {
i: labels[i] if i in labels else {
'name': i
}
for i in labels.keys()
}
for k, v in labels.items():
# Init labels from str
if isinstance(v, str):
@@ -133,8 +141,8 @@ class LabelTensor(torch.Tensor):
# Init from dict with only name key
v['dof'] = range(tensor_shape[k])
# Init from dict with both name and dof keys
elif isinstance(v, dict) and sorted(list(v.keys())) == ['dof',
'name']:
elif isinstance(v, dict) and sorted(list(
v.keys())) == ['dof', 'name']:
dof_list = v['dof']
dof_len = len(dof_list)
if dof_len != len(set(dof_list)):
@@ -143,7 +151,7 @@ class LabelTensor(torch.Tensor):
raise ValueError(
'Number of dof does not match tensor shape')
else:
ValueError('Illegal labels initialization')
raise ValueError('Illegal labels initialization')
# Perform update
self._labels[k] = v
@@ -157,7 +165,11 @@ class LabelTensor(torch.Tensor):
"""
# Create a dict with labels
last_dim_labels = {
self.ndim - 1: {'dof': labels, 'name': self.ndim - 1}}
self.ndim - 1: {
'dof': labels,
'name': self.ndim - 1
}
}
self._init_labels_from_dict(last_dim_labels)
def set_names(self):
@@ -217,9 +229,10 @@ class LabelTensor(torch.Tensor):
v = [v] if isinstance(v, (int, str)) else v
if not isinstance(v, range):
extractor[idx_dim] = [dim_labels.index(i) for i in v] if len(
v) > 1 else slice(dim_labels.index(v[0]),
dim_labels.index(v[0]) + 1)
extractor[idx_dim] = [dim_labels.index(i)
for i in v] if len(v) > 1 else slice(
dim_labels.index(v[0]),
dim_labels.index(v[0]) + 1)
else:
extractor[idx_dim] = slice(v.start, v.stop)
@@ -263,10 +276,10 @@ class LabelTensor(torch.Tensor):
new_tensor = torch.cat(tensors, dim=dim)
# Update labels
labels = LabelTensor.__create_labels_cat(tensors,
dim)
labels = LabelTensor.__create_labels_cat(tensors, dim)
return LabelTensor.__internal_init__(new_tensor, labels, tensors[0].dim_names)
return LabelTensor.__internal_init__(new_tensor, labels,
tensors[0].dim_names)
@staticmethod
def __create_labels_cat(tensors, dim):
@@ -277,9 +290,10 @@ class LabelTensor(torch.Tensor):
# check if:
# - labels dict have same keys
# - all labels are the same expect for dimension dim
if not all(all(stored_labels[i][k] == stored_labels[0][k]
for i in range(len(stored_labels)))
for k in stored_labels[0].keys() if k != dim):
if not all(
all(stored_labels[i][k] == stored_labels[0][k]
for i in range(len(stored_labels)))
for k in stored_labels[0].keys() if k != dim):
raise RuntimeError('tensors must have the same shape and dof')
labels = {k: copy(v) for k, v in tensors[0].stored_labels.items()}
@@ -341,8 +355,12 @@ class LabelTensor(torch.Tensor):
last_dim_labels = ['+'.join(items) for items in zip(*last_dim_labels)]
labels = {k: copy(v) for k, v in tensors[0].stored_labels.items()}
labels.update({tensors[0].ndim - 1: {'dof': last_dim_labels,
'name': tensors[0].name}})
labels.update({
tensors[0].ndim - 1: {
'dof': last_dim_labels,
'name': tensors[0].name
}
})
return LabelTensor(data, labels)
def append(self, tensor, mode='std'):
@@ -384,8 +402,9 @@ class LabelTensor(torch.Tensor):
:param index:
:return:
"""
if isinstance(index, str) or (isinstance(index, (tuple, list)) and all(
isinstance(a, str) for a in index)):
if isinstance(index,
str) or (isinstance(index, (tuple, list))
and all(isinstance(a, str) for a in index)):
return self.extract(index)
selected_lt = super().__getitem__(index)
@@ -418,21 +437,31 @@ class LabelTensor(torch.Tensor):
:return:
"""
old_dof = old_labels[dim]['dof']
if not isinstance(index, (int, slice)) and len(index) == len(
old_dof) and isinstance(old_dof, range):
if not isinstance(
index,
(int, slice)) and len(index) == len(old_dof) and isinstance(
old_dof, range):
return
if isinstance(index, torch.Tensor):
index = index.nonzero(as_tuple=True)[
0] if index.dtype == torch.bool else index.tolist()
index = index.nonzero(
as_tuple=True
)[0] if index.dtype == torch.bool else index.tolist()
if isinstance(index, list):
to_update_labels.update({dim: {
'dof': [old_dof[i] for i in index],
'name': old_labels[dim]['name']}})
to_update_labels.update({
dim: {
'dof': [old_dof[i] for i in index],
'name': old_labels[dim]['name']
}
})
else:
to_update_labels.update({dim: {'dof': old_dof[index],
'name': old_labels[dim]['name']}})
to_update_labels.update(
{dim: {
'dof': old_dof[index],
'name': old_labels[dim]['name']
}})
def sort_labels(self, dim=None):
def arg_sort(lst):
return sorted(range(len(lst)), key=lambda x: lst[x])
@@ -445,7 +474,6 @@ class LabelTensor(torch.Tensor):
return self.__getitem__(indexer)
def __deepcopy__(self, memo):
from copy import deepcopy
cls = self.__class__
result = cls(deepcopy(self.tensor), deepcopy(self.stored_labels))
return result
@@ -454,6 +482,8 @@ class LabelTensor(torch.Tensor):
tensor = super().permute(*dims)
stored_labels = self.stored_labels
keys_list = list(*dims)
labels = {keys_list.index(k): copy(stored_labels[k]) for k in
stored_labels.keys()}
labels = {
keys_list.index(k): copy(stored_labels[k])
for k in stored_labels.keys()
}
return LabelTensor.__internal_init__(tensor, labels, self.dim_names)

View File

@@ -56,9 +56,9 @@ def grad(output_, input_, components=None, d=None):
gradients = torch.autograd.grad(
output_,
input_,
grad_outputs=torch.ones(
output_.size(), dtype=output_.dtype, device=output_.device
),
grad_outputs=torch.ones(output_.size(),
dtype=output_.dtype,
device=output_.device),
create_graph=True,
retain_graph=True,
allow_unused=True,
@@ -85,8 +85,8 @@ def grad(output_, input_, components=None, d=None):
raise RuntimeError
gradients = grad_scalar_output(output_, input_, d)
elif output_.shape[
output_.ndim - 1] >= 2: # vector output ##############################
elif output_.shape[output_.ndim -
1] >= 2: # vector output ##############################
tensor_to_cat = []
for i, c in enumerate(components):
c_output = output_.extract([c])
@@ -281,11 +281,8 @@ def advection(output_, input_, velocity_field, components=None, d=None):
if components is None:
components = output_.labels
tmp = (
grad(output_, input_, components, d)
.reshape(-1, len(components), len(d))
.transpose(0, 1)
)
tmp = (grad(output_, input_, components, d).reshape(-1, len(components),
len(d)).transpose(0, 1))
tmp *= output_.extract(velocity_field)
return tmp.sum(dim=2).T

View File

@@ -7,6 +7,7 @@ from ..condition.domain_equation_condition import DomainEquationCondition
from ..collector import Collector
from copy import deepcopy
class AbstractProblem(metaclass=ABCMeta):
"""
The abstract `AbstractProblem` class. All the class defining a PINA Problem
@@ -55,7 +56,7 @@ class AbstractProblem(metaclass=ABCMeta):
if 'input_points' in v.keys():
to_return[k] = v['input_points']
return to_return
def __deepcopy__(self, memo):
"""
Implements deepcopy for the
@@ -116,9 +117,11 @@ class AbstractProblem(metaclass=ABCMeta):
"""
return self._conditions
def discretise_domain(
self, n, mode="random", variables="all", locations="all"
):
def discretise_domain(self,
n,
mode="random",
variables="all",
locations="all"):
"""
Generate a set of points to span the `Location` of all the conditions of
the problem.
@@ -170,9 +173,10 @@ class AbstractProblem(metaclass=ABCMeta):
# check correct location
if locations == "all":
locations = [name for name in self.conditions.keys()
if isinstance(self.conditions[name],
DomainEquationCondition)]
locations = [
name for name in self.conditions.keys()
if isinstance(self.conditions[name], DomainEquationCondition)
]
else:
if not isinstance(locations, (list)):
locations = [locations]

View File

@@ -142,5 +142,4 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
condition.condition_type):
raise ValueError(
f'{self.__name__} support only dose not support condition '
f'{condition.condition_type}'
)
f'{condition.condition_type}')

View File

@@ -9,8 +9,13 @@ from .solvers.solver import SolverInterface
class Trainer(pytorch_lightning.Trainer):
def __init__(self, solver, batch_size=None, train_size=.7, test_size=.2,
eval_size=.1, **kwargs):
def __init__(self,
solver,
batch_size=None,
train_size=.7,
test_size=.2,
eval_size=.1,
**kwargs):
"""
PINA Trainer class for costumizing every aspect of training via flags.
@@ -48,8 +53,7 @@ class Trainer(pytorch_lightning.Trainer):
if hasattr(pb, "unknown_parameters"):
for key in pb.unknown_parameters:
pb.unknown_parameters[key] = torch.nn.Parameter(
pb.unknown_parameters[key].data.to(device)
)
pb.unknown_parameters[key].data.to(device))
def _create_loader(self):
"""
@@ -58,14 +62,11 @@ class Trainer(pytorch_lightning.Trainer):
trainer dataloader, just call the method.
"""
if not self.solver.problem.collector.full:
error_message = '\n'.join(
[
f"""{" " * 13} ---> Condition {key} {"sampled" if value else
"not sampled"}"""
for key, value in
self._solver.problem.collector._is_conditions_ready.items()
]
)
error_message = '\n'.join([
f"""{" " * 13} ---> Condition {key} {"sampled" if value else
"not sampled"}""" for key, value in
self._solver.problem.collector._is_conditions_ready.items()
])
raise RuntimeError('Cannot create Trainer if not all conditions '
'are sampled. The Trainer got the following:\n'
f'{error_message}')
@@ -76,7 +77,8 @@ class Trainer(pytorch_lightning.Trainer):
device = devices[0]
data_module = PinaDataModule(problem=self.solver.problem, device=device,
data_module = PinaDataModule(problem=self.solver.problem,
device=device,
train_size=self.train_size,
test_size=self.test_size,
val_size=self.eval_size)
@@ -87,9 +89,9 @@ class Trainer(pytorch_lightning.Trainer):
"""
Train the solver method.
"""
return super().fit(
self.solver, train_dataloaders=self._loader, **kwargs
)
return super().fit(self.solver,
train_dataloaders=self._loader,
**kwargs)
@property
def solver(self):

View File

@@ -40,9 +40,9 @@ def check_consistency(object, object_instance, subclass=False):
raise ValueError(f"{type(obj).__name__} must be {object_instance}.")
def number_parameters(
model, aggregate=True, only_trainable=True
): # TODO: check
def number_parameters(model,
aggregate=True,
only_trainable=True): # TODO: check
"""
Return the number of parameters of a given `model`.
@@ -80,9 +80,8 @@ def merge_two_tensors(tensor1, tensor2):
n2 = tensor2.shape[0]
tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels)
tensor2 = LabelTensor(
tensor2.repeat_interleave(n1, dim=0), labels=tensor2.labels
)
tensor2 = LabelTensor(tensor2.repeat_interleave(n1, dim=0),
labels=tensor2.labels)
return tensor1.append(tensor2)

View File

@@ -22,8 +22,11 @@ def test_init_inputoutput():
Condition(input_points=3., output_points='example')
with pytest.raises(ValueError):
Condition(input_points=example_domain, output_points=example_domain)
test_init_inputoutput()
def test_init_domainfunc():
Condition(domain=example_domain, equation=FixedValue(0.0))
with pytest.raises(ValueError):

View File

@@ -32,49 +32,49 @@ class Poisson(SpatialProblem):
conditions = {
'gamma1':
Condition(domain=CartesianDomain({
'x': [0, 1],
'y': 1
}),
equation=FixedValue(0.0)),
Condition(domain=CartesianDomain({
'x': [0, 1],
'y': 1
}),
equation=FixedValue(0.0)),
'gamma2':
Condition(domain=CartesianDomain({
'x': [0, 1],
'y': 0
}),
equation=FixedValue(0.0)),
Condition(domain=CartesianDomain({
'x': [0, 1],
'y': 0
}),
equation=FixedValue(0.0)),
'gamma3':
Condition(domain=CartesianDomain({
'x': 1,
'y': [0, 1]
}),
equation=FixedValue(0.0)),
Condition(domain=CartesianDomain({
'x': 1,
'y': [0, 1]
}),
equation=FixedValue(0.0)),
'gamma4':
Condition(domain=CartesianDomain({
'x': 0,
'y': [0, 1]
}),
equation=FixedValue(0.0)),
Condition(domain=CartesianDomain({
'x': 0,
'y': [0, 1]
}),
equation=FixedValue(0.0)),
'D':
Condition(input_points=LabelTensor(torch.rand(size=(100, 2)),
['x', 'y']),
equation=my_laplace),
Condition(input_points=LabelTensor(torch.rand(size=(100, 2)),
['x', 'y']),
equation=my_laplace),
'data':
Condition(input_points=in_, output_points=out_),
Condition(input_points=in_, output_points=out_),
'data2':
Condition(input_points=in2_, output_points=out2_),
Condition(input_points=in2_, output_points=out2_),
'unsupervised':
Condition(
input_points=LabelTensor(torch.rand(size=(45, 2)), ['x', 'y']),
conditional_variables=LabelTensor(torch.ones(size=(45, 1)),
['alpha']),
),
Condition(
input_points=LabelTensor(torch.rand(size=(45, 2)), ['x', 'y']),
conditional_variables=LabelTensor(torch.ones(size=(45, 1)),
['alpha']),
),
'unsupervised2':
Condition(
input_points=LabelTensor(torch.rand(size=(90, 2)), ['x', 'y']),
conditional_variables=LabelTensor(torch.ones(size=(90, 1)),
['alpha']),
)
Condition(
input_points=LabelTensor(torch.rand(size=(90, 2)), ['x', 'y']),
conditional_variables=LabelTensor(torch.ones(size=(90, 1)),
['alpha']),
)
}
@@ -201,11 +201,12 @@ data = LabelTensor(torch.rand((100, 100, 3)), labels=['ux', 'uy', 'p'])
class GraphProblem(AbstractProblem):
output = LabelTensor(torch.rand((100, 3)), labels=['ux', 'uy', 'p'])
input = [Graph.build('radius',
nodes_coordinates=coordinates[i, :, :],
nodes_data=data[i, :, :], radius=0.2)
for i in
range(100)]
input = [
Graph.build('radius',
nodes_coordinates=coordinates[i, :, :],
nodes_data=data[i, :, :],
radius=0.2) for i in range(100)
]
output_variables = ['u']
conditions = {

View File

@@ -3,6 +3,7 @@ import torch
from pina import LabelTensor
from pina.domain import CartesianDomain
def test_constructor():
CartesianDomain({'x': [0, 1], 'y': [0, 1]})