Codacy correction

Code-style cleanup flagged by Codacy: yapf-style re-wrapping of long statements, PEP 8 blank-line fixes before class definitions, and a couple of small bug fixes (a missing `raise`, an unused parameter).

Committed by Nicola Demo
Parent: ea3d1924e7
Commit: dd43c8304c
@@ -1,7 +1,6 @@
 __all__ = [
-    "Trainer", "LabelTensor", "Plotter", "Condition",
-    "SamplePointDataset", "PinaDataModule", "PinaDataLoader",
-    'TorchOptimizer', 'Graph'
+    "Trainer", "LabelTensor", "Plotter", "Condition", "SamplePointDataset",
+    "PinaDataModule", "PinaDataLoader", 'TorchOptimizer', 'Graph'
 ]
 
 from .meta import *
@@ -2,6 +2,7 @@ from .utils import check_consistency, merge_tensors
 
+
 class Collector:
 
     def __init__(self, problem):
         # creating a hook between collector and problem
         self.problem = problem
@@ -13,12 +14,16 @@ class Collector:
         # }
         # those variables are used for the dataloading
         self._data_collections = {name: {} for name in self.problem.conditions}
-        self.conditions_name = {i: name for i, name in
-                                enumerate(self.problem.conditions)}
+        self.conditions_name = {
+            i: name
+            for i, name in enumerate(self.problem.conditions)
+        }
 
         # variables used to check that all conditions are sampled
         self._is_conditions_ready = {
-            name: False for name in self.problem.conditions}
+            name: False
+            for name in self.problem.conditions
+        }
         self.full = False
 
     @property
@@ -47,8 +52,8 @@ class Collector:
         for condition_name, condition in self.problem.conditions.items():
             # if the condition is not ready and domain is not attribute
             # of condition, we get and store the data
-            if (not self._is_conditions_ready[condition_name]) and (
-                    not hasattr(condition, "domain")):
+            if (not self._is_conditions_ready[condition_name]) and (not hasattr(
+                    condition, "domain")):
                 # get data
                 keys = condition.__slots__
                 values = [getattr(condition, name) for name in keys]
@@ -70,7 +75,8 @@ class Collector:
             # if we have sampled the condition but not all variables
             else:
                 already_sampled = [
-                    self.data_collections[loc]['input_points']]
+                    self.data_collections[loc]['input_points']
+                ]
         # if the condition is ready but we want to sample again
         else:
             self._is_conditions_ready[loc] = False
@@ -78,14 +84,10 @@ class Collector:
 
             # get the samples
             samples = [
-                condition.domain.sample(n=n, mode=mode,
-                                        variables=variables)
+                condition.domain.sample(n=n, mode=mode, variables=variables)
             ] + already_sampled
             pts = merge_tensors(samples)
-            if (
-                    set(pts.labels).issubset(
-                        sorted(self.problem.input_variables))
-            ):
+            if (set(pts.labels).issubset(sorted(self.problem.input_variables))):
                 pts = pts.sort_labels()
                 if sorted(pts.labels) == sorted(self.problem.input_variables):
                     self._is_conditions_ready[loc] = True
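
Note: the Collector hunks above only re-wrap the readiness check; the logic is unchanged: a condition is flagged ready once the merged samples cover every input variable. A minimal standalone sketch of that check with plain lists (the variable names are invented, not PINA's API):

sampled_labels = ['x', 'y']        # labels of the points sampled so far
input_variables = ['x', 'y', 't']  # variables the problem expects

# partial coverage keeps the condition pending; full coverage marks it ready
is_ready = sorted(sampled_labels) == sorted(input_variables)
print(is_ready)  # False: 't' has not been sampled yet
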
@@ -39,21 +39,16 @@ class Condition:
     """
 
     __slots__ = list(
-        set(
-            InputOutputPointsCondition.__slots__ +
+        set(InputOutputPointsCondition.__slots__ +
             InputPointsEquationCondition.__slots__ +
             DomainEquationCondition.__slots__ +
-            DataConditionInterface.__slots__
-        )
-    )
+            DataConditionInterface.__slots__))
 
     def __new__(cls, *args, **kwargs):
 
         if len(args) != 0:
-            raise ValueError(
-                "Condition takes only the following keyword "
-                f"arguments: {Condition.__slots__}."
-            )
+            raise ValueError("Condition takes only the following keyword "
+                             f"arguments: {Condition.__slots__}.")
 
         sorted_keys = sorted(kwargs.keys())
         if sorted_keys == sorted(InputOutputPointsCondition.__slots__):
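
Note: Condition.__slots__ is the union of the slots of all concrete condition types, and __new__ dispatches on which slot set the keyword arguments match. A hedged standalone sketch of that dispatch idea (the class and slot names here are invented):

class A:
    __slots__ = ['input_points', 'output_points']

class B:
    __slots__ = ['input_points', 'equation']

# union of all accepted keyword names, duplicates removed
ALL_SLOTS = list(set(A.__slots__ + B.__slots__))

def dispatch(**kwargs):
    # pick the concrete class whose slots match the given keywords
    if sorted(kwargs.keys()) == sorted(A.__slots__):
        return 'A'
    if sorted(kwargs.keys()) == sorted(B.__slots__):
        return 'B'
    raise ValueError(f'Condition takes only keyword arguments from {ALL_SLOTS}.')

print(dispatch(input_points=1, equation=2))  # B
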
@@ -1,6 +1,6 @@
-
 from abc import ABCMeta
 
+
 class ConditionInterface(metaclass=ABCMeta):
 
     condition_types = ['physics', 'supervised', 'unsupervised']
@@ -29,6 +29,5 @@ class ConditionInterface(metaclass=ABCMeta):
         if value not in ConditionInterface.condition_types:
             raise ValueError(
                 'Unavailable type of condition, expected one of'
-                f' {ConditionInterface.condition_types}.'
-            )
+                f' {ConditionInterface.condition_types}.')
         self._condition_type = values
@@ -5,6 +5,7 @@ from ..label_tensor import LabelTensor
 from ..graph import Graph
 from ..utils import check_consistency
 
+
 class DataConditionInterface(ConditionInterface):
     """
     Condition for data. This condition must be used every
@@ -5,6 +5,7 @@ from ..utils import check_consistency
 from ..domain import DomainInterface
 from ..equation.equation_interface import EquationInterface
 
+
 class DomainEquationCondition(ConditionInterface):
     """
     Condition for domain/equation data. This condition must be used every
@@ -6,6 +6,7 @@ from ..graph import Graph
 from ..utils import check_consistency
 from ..equation.equation_interface import EquationInterface
 
+
 class InputPointsEquationCondition(ConditionInterface):
     """
     Condition for input_points/equation data. This condition must be used every
@@ -25,7 +26,9 @@ class InputPointsEquationCondition(ConditionInterface):
 
     def __setattr__(self, key, value):
         if key == 'input_points':
-            check_consistency(value, (LabelTensor))  # for now only labeltensors, we need labels for the operators!
+            check_consistency(
+                value, (LabelTensor)
+            )  # for now only labeltensors, we need labels for the operators!
             InputPointsEquationCondition.__dict__[key].__set__(self, value)
         elif key == 'equation':
             check_consistency(value, (EquationInterface))
@@ -1,4 +1,3 @@
-
 import torch
 
 from .condition_interface import ConditionInterface
@@ -6,6 +5,7 @@ from ..label_tensor import LabelTensor
 from ..graph import Graph
 from ..utils import check_consistency
 
+
 class InputOutputPointsCondition(ConditionInterface):
     """
     Condition for domain/equation data. This condition must be used every
@@ -59,9 +59,11 @@ class BaseDataset(Dataset):
             keys = list(data.keys())
             if set(self.__slots__) == set(keys):
                 self._populate_init_list(data)
-                idx = [key for key, val in
-                       self.problem.collector.conditions_name.items() if
-                       val == name]
+                idx = [
+                    key for key, val in
+                    self.problem.collector.conditions_name.items()
+                    if val == name
+                ]
                 self.conditions_idx.append(idx)
         self.initialize()
 
@@ -89,15 +91,16 @@ class BaseDataset(Dataset):
             if isinstance(slot_data, (LabelTensor, torch.Tensor)):
                 dims = len(slot_data.size())
                 slot_data = slot_data.permute(
-                    [batching_dim] + [dim for dim in range(dims) if
-                                      dim != batching_dim])
+                    [batching_dim] +
+                    [dim for dim in range(dims) if dim != batching_dim])
             if current_cond_num_el is None:
                 current_cond_num_el = len(slot_data)
             elif current_cond_num_el != len(slot_data):
                 raise ValueError('Different dimension in same condition')
             current_list = getattr(self, slot)
-            current_list += [slot_data] if not (
-                isinstance(slot_data, list)) else slot_data
+            current_list += [
+                slot_data
+            ] if not (isinstance(slot_data, list)) else slot_data
         self.num_el_per_condition.append(current_cond_num_el)
 
     def initialize(self):
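
Note: the permute hunk only re-wraps the argument list; the call itself moves the chosen batching dimension to the front while preserving the order of the remaining dimensions. A runnable sketch with torch (the shape and batching_dim are invented):

import torch

data = torch.rand(5, 100, 3)  # e.g. (time, points, channels)
batching_dim = 1              # batch over the points dimension

dims = len(data.size())
# bring batching_dim first, preserve the order of the other dims
data = data.permute([batching_dim] +
                    [dim for dim in range(dims) if dim != batching_dim])
print(data.shape)  # torch.Size([100, 5, 3])
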
@@ -108,14 +111,12 @@ class BaseDataset(Dataset):
         logging.debug(f'Initialize dataset {self.__class__.__name__}')
 
         if self.num_el_per_condition:
-            self.condition_indices = torch.cat(
-                [
+            self.condition_indices = torch.cat([
                 torch.tensor([i] * self.num_el_per_condition[i],
                              dtype=torch.uint8)
                 for i in range(len(self.num_el_per_condition))
             ],
-                dim=0
-            )
+                                               dim=0)
         for slot in self.__slots__:
             current_attribute = getattr(self, slot)
             if all(isinstance(a, LabelTensor) for a in current_attribute):
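
Note: condition_indices tags each row with the index of the condition it belongs to, so batches can later be split per condition. A runnable sketch of the same torch.cat pattern (the element counts are invented):

import torch

num_el_per_condition = [3, 2]  # e.g. 3 rows for condition 0, 2 for condition 1

condition_indices = torch.cat([
    torch.tensor([i] * num_el_per_condition[i], dtype=torch.uint8)
    for i in range(len(num_el_per_condition))
], dim=0)
print(condition_indices)  # tensor([0, 0, 0, 1, 1], dtype=torch.uint8)
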
@@ -44,8 +44,9 @@ class PinaDataModule(LightningDataModule):
         super().__init__()
         self.problem = problem
         self.device = device
-        self.dataset_classes = [SupervisedDataset, UnsupervisedDataset,
-                                SamplePointDataset]
+        self.dataset_classes = [
+            SupervisedDataset, UnsupervisedDataset, SamplePointDataset
+        ]
         if datasets is None:
             self.datasets = None
         else:
@@ -71,15 +72,12 @@ class PinaDataModule(LightningDataModule):
             self.split_length.append(val_size)
             self.split_names.append('val')
             self.loader_functions['val_dataloader'] = lambda: PinaDataLoader(
-                self.splits['val'], self.batch_size,
-                self.condition_names)
+                self.splits['val'], self.batch_size, self.condition_names)
         if predict_size > 0:
             self.split_length.append(predict_size)
             self.split_names.append('predict')
-            self.loader_functions[
-                'predict_dataloader'] = lambda: PinaDataLoader(
-                    self.splits['predict'], self.batch_size,
-                    self.condition_names)
+            self.loader_functions['predict_dataloader'] = lambda: PinaDataLoader(
+                self.splits['predict'], self.batch_size, self.condition_names)
         self.splits = {k: {} for k in self.split_names}
         self.shuffle = shuffle
 
@@ -104,8 +102,8 @@ class PinaDataModule(LightningDataModule):
                                        self.split_length,
                                        shuffle=self.shuffle)
                 for i in range(len(self.split_length)):
-                    self.splits[
-                        self.split_names[i]][dataset.data_type] = splits[i]
+                    self.splits[self.split_names[i]][
+                        dataset.data_type] = splits[i]
         elif stage == 'test':
             raise NotImplementedError("Testing pipeline not implemented yet")
         else:
@@ -137,14 +135,12 @@ class PinaDataModule(LightningDataModule):
             if seed is not None:
                 generator = torch.Generator()
                 generator.manual_seed(seed)
-                indices = torch.randperm(sum(lengths),
-                                         generator=generator)
+                indices = torch.randperm(sum(lengths), generator=generator)
             else:
                 indices = torch.randperm(sum(lengths))
             dataset.apply_shuffle(indices)
 
-        indices = torch.arange(0, sum(lengths), 1,
-                               dtype=torch.uint8).tolist()
+        indices = torch.arange(0, sum(lengths), 1, dtype=torch.uint8).tolist()
         offsets = [
             sum(lengths[:i]) if i > 0 else 0 for i in range(len(lengths))
         ]
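
Note: the split logic shuffles with a seeded torch.Generator so runs are reproducible, then slices the permutation by per-split offsets. A runnable sketch (the lengths and seed are invented):

import torch

lengths = [7, 2, 1]  # e.g. train/val/test sizes
seed = 42

generator = torch.Generator()
generator.manual_seed(seed)
indices = torch.randperm(sum(lengths), generator=generator)

# offset of each split inside the shuffled index list
offsets = [sum(lengths[:i]) if i > 0 else 0 for i in range(len(lengths))]
splits = [indices[offsets[i]:offsets[i] + lengths[i]] for i in range(len(lengths))]
print([len(s) for s in splits])  # [7, 2, 1]
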
@@ -161,13 +157,16 @@ class PinaDataModule(LightningDataModule):
         collector = self.problem.collector
         batching_dim = self.problem.batching_dimension
         datasets_slots = [i.__slots__ for i in self.dataset_classes]
-        self.datasets = [dataset(device=self.device) for dataset in
-                         self.dataset_classes]
+        self.datasets = [
+            dataset(device=self.device) for dataset in self.dataset_classes
+        ]
         logging.debug('Filling datasets in PinaDataModule obj')
         for name, data in collector.data_collections.items():
             keys = list(data.keys())
-            idx = [key for key, val in collector.conditions_name.items() if
-                   val == name]
+            idx = [
+                key for key, val in collector.conditions_name.items()
+                if val == name
+            ]
             for i, slot in enumerate(datasets_slots):
                 if slot == keys:
                     self.datasets[i].add_points(data, idx[0], batching_dim)
@@ -37,10 +37,7 @@ class Batch:
         if item in super().__getattribute__('attributes'):
             dataset = super().__getattribute__(item)
             index = super().__getattribute__(item + '_idx')
-            return PinaSubset(
-                dataset.dataset,
-                dataset.indices[index])
-        else:
-            return super().__getattribute__(item)
+            return PinaSubset(dataset.dataset, dataset.indices[index])
+        return super().__getattribute__(item)
 
     def __getattr__(self, item):
@@ -19,15 +19,17 @@ class SamplePointDataset(BaseDataset):
         data_dict.pop('equation')
         super().add_points(data_dict, condition_idx)
 
-    def _init_from_problem(self, collector_dict, batching_dim=0):
+    def _init_from_problem(self, collector_dict):
         for name, data in collector_dict.items():
             keys = list(data.keys())
             if set(self.__slots__) == set(keys):
                 data = deepcopy(data)
                 data.pop('equation')
                 self._populate_init_list(data)
-                idx = [key for key, val in
-                       self.problem.collector.conditions_name.items() if
-                       val == name]
+                idx = [
+                    key for key, val in
+                    self.problem.collector.conditions_name.items()
+                    if val == name
+                ]
                 self.conditions_idx.append(idx)
         self.initialize()
@@ -168,9 +168,8 @@ class CartesianDomain(DomainInterface):
         for variable in variables:
             if variable in self.fixed_.keys():
                 value = self.fixed_[variable]
-                pts_variable = torch.tensor([[value]]).repeat(
-                    result.shape[0], 1
-                )
+                pts_variable = torch.tensor([[value]
+                                             ]).repeat(result.shape[0], 1)
                 pts_variable = pts_variable.as_subclass(LabelTensor)
                 pts_variable.labels = [variable]
 
@@ -203,9 +202,8 @@ class CartesianDomain(DomainInterface):
         for variable in variables:
             if variable in self.fixed_.keys():
                 value = self.fixed_[variable]
-                pts_variable = torch.tensor([[value]]).repeat(
-                    result.shape[0], 1
-                )
+                pts_variable = torch.tensor([[value]
+                                             ]).repeat(result.shape[0], 1)
                 pts_variable = pts_variable.as_subclass(LabelTensor)
                 pts_variable.labels = [variable]
 
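
Note: both CartesianDomain hunks reformat the same statement: a fixed variable is materialised by tiling its constant value once per sampled row. A runnable sketch (the value and row count are invented):

import torch

value = 3.0   # fixed value of the variable
num_rows = 4  # rows already sampled for the free variables

# one column holding the constant, repeated once per sampled row
pts_variable = torch.tensor([[value]]).repeat(num_rows, 1)
print(pts_variable.shape)  # torch.Size([4, 1])
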
@@ -38,8 +38,7 @@ class DomainInterface(metaclass=ABCMeta):
         if value not in DomainInterface.available_sampling_modes:
             raise TypeError(f"mode {value} not valid. Expected at least "
                             "one in "
-                            f"{DomainInterface.available_sampling_modes}."
-                            )
+                            f"{DomainInterface.available_sampling_modes}.")
 
     @abstractmethod
     def sample(self):
@@ -1,5 +1,5 @@
 """ Module for LabelTensor """
-from copy import copy
+from copy import copy, deepcopy
 import torch
 from torch import Tensor
 
@@ -10,9 +10,8 @@ def issubset(a, b):
     """
     if isinstance(a, list) and isinstance(b, list):
         return set(a).issubset(set(b))
-    elif isinstance(a, range) and isinstance(b, range):
+    if isinstance(a, range) and isinstance(b, range):
         return a.start <= b.start and a.stop >= b.stop
-    else:
-        return False
+    return False
 
 
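
Note: this hunk is the usual lint fix for branches that all return: once a branch returns, the following elif/else flatten to plain if/return. For reference, the helper after the change, assembled into runnable form:

def issubset(a, b):
    # list case: plain set inclusion
    if isinstance(a, list) and isinstance(b, list):
        return set(a).issubset(set(b))
    # range case: containment of the bounds
    if isinstance(a, range) and isinstance(b, range):
        return a.start <= b.start and a.stop >= b.stop
    return False

print(issubset([1, 2], [1, 2, 3]))         # True
print(issubset(range(2, 4), range(1, 5)))  # True
print(issubset('ab', 'abc'))               # False: unsupported types
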
@@ -20,7 +19,7 @@ class LabelTensor(torch.Tensor):
     """Torch tensor with a label for any column."""
 
     @staticmethod
-    def __new__(cls, x, labels, full=True, *args, **kwargs):
+    def __new__(cls, x, labels, *args, **kwargs):
         if isinstance(x, LabelTensor):
             return x
         else:
@@ -30,7 +29,7 @@ class LabelTensor(torch.Tensor):
     def tensor(self):
         return self.as_subclass(Tensor)
 
-    def __init__(self, x, labels, full=False):
+    def __init__(self, x, labels, **kwargs):
         """
         Construct a `LabelTensor` by passing a dict of the labels
 
@@ -42,14 +41,19 @@ class LabelTensor(torch.Tensor):
 
         """
         self.dim_names = None
-        self.full = full
+        self.full = kwargs.get('full', True)
         self.labels = labels
 
     @classmethod
-    def __internal_init__(cls, x, labels, dim_names ,full=False, *args, **kwargs):
-        lt = cls.__new__(cls, x, labels, full, *args, **kwargs)
+    def __internal_init__(cls,
+                          x,
+                          labels,
+                          dim_names,
+                          *args,
+                          **kwargs):
+        lt = cls.__new__(cls, x, labels, *args, **kwargs)
         lt._labels = labels
-        lt.full = full
+        lt.full = kwargs.get('full', True)
         lt.dim_names = dim_names
         return lt
 
@@ -122,8 +126,12 @@ class LabelTensor(torch.Tensor):
         tensor_shape = self.shape
 
         if hasattr(self, 'full') and self.full:
-            labels = {i: labels[i] if i in labels else {'name': i} for i in
-                      labels.keys()}
+            labels = {
+                i: labels[i] if i in labels else {
+                    'name': i
+                }
+                for i in labels.keys()
+            }
         for k, v in labels.items():
             # Init labels from str
             if isinstance(v, str):
@@ -133,8 +141,8 @@ class LabelTensor(torch.Tensor):
                 # Init from dict with only name key
                 v['dof'] = range(tensor_shape[k])
             # Init from dict with both name and dof keys
-            elif isinstance(v, dict) and sorted(list(v.keys())) == ['dof',
-                                                                    'name']:
+            elif isinstance(v, dict) and sorted(list(
+                    v.keys())) == ['dof', 'name']:
                 dof_list = v['dof']
                 dof_len = len(dof_list)
                 if dof_len != len(set(dof_list)):
@@ -143,7 +151,7 @@ class LabelTensor(torch.Tensor):
                     raise ValueError(
                         'Number of dof does not match tensor shape')
             else:
-                ValueError('Illegal labels initialization')
+                raise ValueError('Illegal labels initialization')
             # Perform update
             self._labels[k] = v
 
@@ -157,7 +165,11 @@ class LabelTensor(torch.Tensor):
         """
         # Create a dict with labels
         last_dim_labels = {
-            self.ndim - 1: {'dof': labels, 'name': self.ndim - 1}}
+            self.ndim - 1: {
+                'dof': labels,
+                'name': self.ndim - 1
+            }
+        }
         self._init_labels_from_dict(last_dim_labels)
 
     def set_names(self):
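
Note: the hunk above expands a one-liner into the dict-of-dicts layout {dim: {'dof': ..., 'name': ...}} that LabelTensor uses for per-dimension labels. A plain-Python sketch of that layout (the label names are invented):

ndim = 2
labels = ['u', 'v', 'p']  # one label per column of the last dimension

# normalise a flat label list into the {dim: {'dof': ..., 'name': ...}} format
last_dim_labels = {
    ndim - 1: {
        'dof': labels,
        'name': ndim - 1
    }
}
print(last_dim_labels)  # {1: {'dof': ['u', 'v', 'p'], 'name': 1}}
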
@@ -217,8 +229,9 @@ class LabelTensor(torch.Tensor):
             v = [v] if isinstance(v, (int, str)) else v
 
             if not isinstance(v, range):
-                extractor[idx_dim] = [dim_labels.index(i) for i in v] if len(
-                    v) > 1 else slice(dim_labels.index(v[0]),
+                extractor[idx_dim] = [dim_labels.index(i)
+                                      for i in v] if len(v) > 1 else slice(
+                                          dim_labels.index(v[0]),
                                       dim_labels.index(v[0]) + 1)
             else:
                 extractor[idx_dim] = slice(v.start, v.stop)
@@ -263,10 +276,10 @@ class LabelTensor(torch.Tensor):
         new_tensor = torch.cat(tensors, dim=dim)
 
         # Update labels
-        labels = LabelTensor.__create_labels_cat(tensors,
-                                                 dim)
+        labels = LabelTensor.__create_labels_cat(tensors, dim)
 
-        return LabelTensor.__internal_init__(new_tensor, labels, tensors[0].dim_names)
+        return LabelTensor.__internal_init__(new_tensor, labels,
+                                             tensors[0].dim_names)
 
     @staticmethod
     def __create_labels_cat(tensors, dim):
@@ -277,7 +290,8 @@ class LabelTensor(torch.Tensor):
         # check if:
         # - labels dict have same keys
         # - all labels are the same expect for dimension dim
-        if not all(all(stored_labels[i][k] == stored_labels[0][k]
+        if not all(
+                all(stored_labels[i][k] == stored_labels[0][k]
                     for i in range(len(stored_labels)))
                 for k in stored_labels[0].keys() if k != dim):
             raise RuntimeError('tensors must have the same shape and dof')
@@ -341,8 +355,12 @@ class LabelTensor(torch.Tensor):
 
         last_dim_labels = ['+'.join(items) for items in zip(*last_dim_labels)]
         labels = {k: copy(v) for k, v in tensors[0].stored_labels.items()}
-        labels.update({tensors[0].ndim - 1: {'dof': last_dim_labels,
-                                             'name': tensors[0].name}})
+        labels.update({
+            tensors[0].ndim - 1: {
+                'dof': last_dim_labels,
+                'name': tensors[0].name
+            }
+        })
         return LabelTensor(data, labels)
 
     def append(self, tensor, mode='std'):
@@ -384,8 +402,9 @@ class LabelTensor(torch.Tensor):
         :param index:
         :return:
         """
-        if isinstance(index, str) or (isinstance(index, (tuple, list)) and all(
-                isinstance(a, str) for a in index)):
+        if isinstance(index,
+                      str) or (isinstance(index, (tuple, list))
+                               and all(isinstance(a, str) for a in index)):
             return self.extract(index)
 
         selected_lt = super().__getitem__(index)
@@ -418,21 +437,31 @@ class LabelTensor(torch.Tensor):
         :return:
         """
         old_dof = old_labels[dim]['dof']
-        if not isinstance(index, (int, slice)) and len(index) == len(
-                old_dof) and isinstance(old_dof, range):
+        if not isinstance(
+                index,
+                (int, slice)) and len(index) == len(old_dof) and isinstance(
+                    old_dof, range):
             return
         if isinstance(index, torch.Tensor):
-            index = index.nonzero(as_tuple=True)[
-                0] if index.dtype == torch.bool else index.tolist()
+            index = index.nonzero(
+                as_tuple=True
+            )[0] if index.dtype == torch.bool else index.tolist()
         if isinstance(index, list):
-            to_update_labels.update({dim: {
+            to_update_labels.update({
+                dim: {
                     'dof': [old_dof[i] for i in index],
-                'name': old_labels[dim]['name']}})
+                    'name': old_labels[dim]['name']
+                }
+            })
         else:
-            to_update_labels.update({dim: {'dof': old_dof[index],
-                                           'name': old_labels[dim]['name']}})
+            to_update_labels.update(
+                {dim: {
+                    'dof': old_dof[index],
+                    'name': old_labels[dim]['name']
+                }})
 
     def sort_labels(self, dim=None):
 
         def arg_sort(lst):
             return sorted(range(len(lst)), key=lambda x: lst[x])
@@ -445,7 +474,6 @@ class LabelTensor(torch.Tensor):
         return self.__getitem__(indexer)
 
     def __deepcopy__(self, memo):
-        from copy import deepcopy
         cls = self.__class__
         result = cls(deepcopy(self.tensor), deepcopy(self.stored_labels))
         return result
@@ -454,6 +482,8 @@ class LabelTensor(torch.Tensor):
         tensor = super().permute(*dims)
         stored_labels = self.stored_labels
         keys_list = list(*dims)
-        labels = {keys_list.index(k): copy(stored_labels[k]) for k in
-                  stored_labels.keys()}
+        labels = {
+            keys_list.index(k): copy(stored_labels[k])
+            for k in stored_labels.keys()
+        }
         return LabelTensor.__internal_init__(tensor, labels, self.dim_names)
@@ -56,9 +56,9 @@ def grad(output_, input_, components=None, d=None):
     gradients = torch.autograd.grad(
         output_,
         input_,
-        grad_outputs=torch.ones(
-            output_.size(), dtype=output_.dtype, device=output_.device
-        ),
+        grad_outputs=torch.ones(output_.size(),
+                                dtype=output_.dtype,
+                                device=output_.device),
         create_graph=True,
         retain_graph=True,
         allow_unused=True,
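
Note: the reformatted call is the standard torch.autograd.grad invocation with unit grad_outputs, i.e. the gradient of the sum of outputs. A runnable sketch on an invented scalar function:

import torch

x = torch.tensor([[1.0, 2.0]], requires_grad=True)
y = (x ** 2).sum(dim=-1, keepdim=True)  # y = x1^2 + x2^2

gradients = torch.autograd.grad(
    y,
    x,
    grad_outputs=torch.ones(y.size(), dtype=y.dtype, device=y.device),
    create_graph=True,   # keep the graph so higher-order derivatives work
    retain_graph=True,
    allow_unused=True,
)[0]
print(gradients)  # dy/dx = 2x -> tensor([[2., 4.]], ...)
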
@@ -85,8 +85,8 @@ def grad(output_, input_, components=None, d=None):
             raise RuntimeError
         gradients = grad_scalar_output(output_, input_, d)
 
-    elif output_.shape[
-        output_.ndim - 1] >= 2:  # vector output ##############################
+    elif output_.shape[output_.ndim -
+                       1] >= 2:  # vector output ##############################
         tensor_to_cat = []
         for i, c in enumerate(components):
             c_output = output_.extract([c])
@@ -281,11 +281,8 @@ def advection(output_, input_, velocity_field, components=None, d=None):
     if components is None:
         components = output_.labels
 
-    tmp = (
-        grad(output_, input_, components, d)
-        .reshape(-1, len(components), len(d))
-        .transpose(0, 1)
-    )
+    tmp = (grad(output_, input_, components, d).reshape(-1, len(components),
+                                                        len(d)).transpose(0, 1))
 
     tmp *= output_.extract(velocity_field)
     return tmp.sum(dim=2).T
@@ -7,6 +7,7 @@ from ..condition.domain_equation_condition import DomainEquationCondition
 from ..collector import Collector
 from copy import deepcopy
 
+
 class AbstractProblem(metaclass=ABCMeta):
     """
     The abstract `AbstractProblem` class. All the class defining a PINA Problem
@@ -116,9 +117,11 @@ class AbstractProblem(metaclass=ABCMeta):
         """
         return self._conditions
 
-    def discretise_domain(
-        self, n, mode="random", variables="all", locations="all"
-    ):
+    def discretise_domain(self,
+                          n,
+                          mode="random",
+                          variables="all",
+                          locations="all"):
         """
         Generate a set of points to span the `Location` of all the conditions of
         the problem.
@@ -170,9 +173,10 @@ class AbstractProblem(metaclass=ABCMeta):
 
         # check correct location
         if locations == "all":
-            locations = [name for name in self.conditions.keys()
-                         if isinstance(self.conditions[name],
-                                       DomainEquationCondition)]
+            locations = [
+                name for name in self.conditions.keys()
+                if isinstance(self.conditions[name], DomainEquationCondition)
+            ]
         else:
             if not isinstance(locations, (list)):
                 locations = [locations]
@@ -142,5 +142,4 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
                     condition.condition_type):
                 raise ValueError(
                     f'{self.__name__} support only dose not support condition '
-                    f'{condition.condition_type}'
-                )
+                    f'{condition.condition_type}')
@@ -9,8 +9,13 @@ from .solvers.solver import SolverInterface
 
 class Trainer(pytorch_lightning.Trainer):
 
-    def __init__(self, solver, batch_size=None, train_size=.7, test_size=.2,
-                 eval_size=.1, **kwargs):
+    def __init__(self,
+                 solver,
+                 batch_size=None,
+                 train_size=.7,
+                 test_size=.2,
+                 eval_size=.1,
+                 **kwargs):
         """
         PINA Trainer class for costumizing every aspect of training via flags.
 
@@ -48,8 +53,7 @@ class Trainer(pytorch_lightning.Trainer):
         if hasattr(pb, "unknown_parameters"):
             for key in pb.unknown_parameters:
                 pb.unknown_parameters[key] = torch.nn.Parameter(
-                    pb.unknown_parameters[key].data.to(device)
-                )
+                    pb.unknown_parameters[key].data.to(device))
 
     def _create_loader(self):
         """
@@ -58,14 +62,11 @@ class Trainer(pytorch_lightning.Trainer):
         trainer dataloader, just call the method.
         """
         if not self.solver.problem.collector.full:
-            error_message = '\n'.join(
-                [
+            error_message = '\n'.join([
                 f"""{" " * 13} ---> Condition {key} {"sampled" if value else
-                    "not sampled"}"""
-                    for key, value in
+                "not sampled"}""" for key, value in
                 self._solver.problem.collector._is_conditions_ready.items()
-                ]
-            )
+            ])
             raise RuntimeError('Cannot create Trainer if not all conditions '
                                'are sampled. The Trainer got the following:\n'
                                f'{error_message}')
@@ -76,7 +77,8 @@ class Trainer(pytorch_lightning.Trainer):
 
         device = devices[0]
 
-        data_module = PinaDataModule(problem=self.solver.problem, device=device,
+        data_module = PinaDataModule(problem=self.solver.problem,
+                                     device=device,
                                      train_size=self.train_size,
                                      test_size=self.test_size,
                                      val_size=self.eval_size)
@@ -87,9 +89,9 @@ class Trainer(pytorch_lightning.Trainer):
         """
         Train the solver method.
         """
-        return super().fit(
-            self.solver, train_dataloaders=self._loader, **kwargs
-        )
+        return super().fit(self.solver,
+                           train_dataloaders=self._loader,
+                           **kwargs)
 
     @property
     def solver(self):
@@ -40,9 +40,9 @@ def check_consistency(object, object_instance, subclass=False):
             raise ValueError(f"{type(obj).__name__} must be {object_instance}.")
 
 
-def number_parameters(
-    model, aggregate=True, only_trainable=True
-):  # TODO: check
+def number_parameters(model,
+                      aggregate=True,
+                      only_trainable=True):  # TODO: check
     """
     Return the number of parameters of a given `model`.
 
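
Note: a helper like number_parameters usually reduces to a numel sum over the (optionally trainable) parameters; the sketch below shows that common pattern, not necessarily PINA's exact implementation (the toy model is invented):

import torch

model = torch.nn.Linear(10, 2)  # toy model: weight (2x10) + bias (2)

def count_parameters(model, only_trainable=True):
    # sum the element count of each (trainable) parameter tensor
    return sum(p.numel() for p in model.parameters()
               if p.requires_grad or not only_trainable)

print(count_parameters(model))  # 22
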
@@ -80,9 +80,8 @@ def merge_two_tensors(tensor1, tensor2):
     n2 = tensor2.shape[0]
 
     tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels)
-    tensor2 = LabelTensor(
-        tensor2.repeat_interleave(n1, dim=0), labels=tensor2.labels
-    )
+    tensor2 = LabelTensor(tensor2.repeat_interleave(n1, dim=0),
+                          labels=tensor2.labels)
     return tensor1.append(tensor2)
 
 
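
Note: merge_two_tensors builds the cartesian product of two point sets: repeat tiles the first tensor, repeat_interleave stretches the second so every pair occurs exactly once. A runnable sketch with plain tensors (the shapes are invented):

import torch

t1 = torch.tensor([[0.], [1.]])           # 2 samples of variable 'x'
t2 = torch.tensor([[10.], [20.], [30.]])  # 3 samples of variable 'y'
n1, n2 = t1.shape[0], t2.shape[0]

# tile t1 and repeat-interleave t2 so all (x, y) combinations line up
x = t1.repeat(n2, 1)
y = t2.repeat_interleave(n1, dim=0)
merged = torch.cat([x, y], dim=-1)
print(merged.shape)  # torch.Size([6, 2]): the 2 x 3 cartesian product
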
@@ -22,8 +22,11 @@ def test_init_inputoutput():
         Condition(input_points=3., output_points='example')
     with pytest.raises(ValueError):
         Condition(input_points=example_domain, output_points=example_domain)
 
 
+test_init_inputoutput()
+
+
 def test_init_domainfunc():
     Condition(domain=example_domain, equation=FixedValue(0.0))
     with pytest.raises(ValueError):
@@ -201,11 +201,12 @@ data = LabelTensor(torch.rand((100, 100, 3)), labels=['ux', 'uy', 'p'])
 
 class GraphProblem(AbstractProblem):
     output = LabelTensor(torch.rand((100, 3)), labels=['ux', 'uy', 'p'])
-    input = [Graph.build('radius',
+    input = [
+        Graph.build('radius',
                     nodes_coordinates=coordinates[i, :, :],
-                         nodes_data=data[i, :, :], radius=0.2)
-             for i in
-             range(100)]
+                    nodes_data=data[i, :, :],
+                    radius=0.2) for i in range(100)
+    ]
     output_variables = ['u']
 
     conditions = {
@@ -3,6 +3,7 @@ import torch
 from pina import LabelTensor
 from pina.domain import CartesianDomain
 
+
 def test_constructor():
     CartesianDomain({'x': [0, 1], 'y': [0, 1]})
 