* Adding Collector for handling data sampling/collection before dataset/dataloader
* Modify domain by adding sample_mode, variables as property
* Small change concatenate -> cat in lno/avno
* Create different factory classes for conditions
committed by Nicola Demo
parent aef5a5d590
commit 1bd3f40f54

pina/collector.py (new file, +72 lines)
@@ -0,0 +1,72 @@
from .utils import check_consistency, merge_tensors


class Collector:
    def __init__(self, problem):
        self.problem = problem  # hook Collector <-> Problem
        self.data_collections = {name: {} for name in self.problem.conditions}  # collection of data
        self.is_conditions_ready = {
            name: False for name in self.problem.conditions}  # names of the conditions that need to be sampled
        self.full = False  # collector full, all points for all conditions are given and the data are ready to be used in trainig

    @property
    def full(self):
        return all(self.is_conditions_ready.values())

    @full.setter
    def full(self, value):
        check_consistency(value, bool)
        self._full = value

    @property
    def problem(self):
        return self._problem

    @problem.setter
    def problem(self, value):
        self._problem = value

    def store_fixed_data(self):
        # loop over all conditions
        for condition_name, condition in self.problem.conditions.items():
            # if the condition is not ready and domain is not attribute
            # of condition, we get and store the data
            if (not self.is_conditions_ready[condition_name]) and (not hasattr(condition, "domain")):
                # get data
                keys = condition.__slots__
                values = [getattr(condition, name) for name in keys]
                self.data_collections[condition_name] = dict(zip(keys, values))
                # condition now is ready
                self.is_conditions_ready[condition_name] = True

    def store_sample_domains(self, n, mode, variables, sample_locations):
        # loop over all locations
        for loc in sample_locations:
            # get condition
            condition = self.problem.conditions[loc]
            keys = ["input_points", "equation"]
            # if the condition is not ready, we get and store the data
            if (not self.is_conditions_ready[loc]):
                # if it is the first time we sample
                if not self.data_collections[loc]:
                    already_sampled = []
                # if we have sampled the condition but not all variables
                else:
                    already_sampled = [self.data_collections[loc].input_points]
            # if the condition is ready but we want to sample again
            else:
                self.is_conditions_ready[loc] = False
                already_sampled = []

            # get the samples
            samples = [
                condition.domain.sample(n=n, mode=mode, variables=variables)
            ] + already_sampled
            pts = merge_tensors(samples)
            if (
                sorted(self.data_collections[loc].input_points.labels)
                == sorted(self.problem.input_variables)
            ):
                self.is_conditions_ready[loc] = True
            values = [pts, condition.equation]
            self.data_collections[loc] = dict(zip(keys, values))
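As a rough orientation for how this class is meant to be driven (a sketch only, not part of the commit; the `problem` object and its condition names are hypothetical), fixed-data conditions are stored eagerly and domain-based conditions are filled in later by sampling:

```python
# sketch: `problem` is any AbstractProblem subclass instance with a
# fixed-data condition 'data' and a domain-based condition 'D'
collector = Collector(problem)          # hook collector <-> problem

# conditions that already carry tensors are stored immediately
collector.store_fixed_data()
print(collector.is_conditions_ready)    # e.g. {'data': True, 'D': False}
print(collector.full)                   # False: 'D' still needs sampling

# domain-based conditions are filled by sampling their domain
collector.store_sample_domains(
    n=100,
    mode="random",
    variables=problem.input_variables,
    sample_locations=["D"],
)
```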
@@ -39,11 +39,10 @@ class Condition:

     __slots__ = list(
         set(
-            InputOutputPointsCondition.__slots__,
-            InputPointsEquationCondition.__slots__,
-            DomainEquationCondition.__slots__,
+            InputOutputPointsCondition.__slots__ +
+            InputPointsEquationCondition.__slots__ +
+            DomainEquationCondition.__slots__ +
             DataConditionInterface.__slots__
         )
     )
@@ -51,8 +50,8 @@ class Condition:

         if len(args) != 0:
             raise ValueError(
-                f"Condition takes only the following keyword '
-                'arguments: {Condition.__slots__}."
+                "Condition takes only the following keyword "
+                f"arguments: {Condition.__slots__}."
             )

         sorted_keys = sorted(kwargs.keys())
@@ -1,18 +1,27 @@
 from abc import ABCMeta


 class ConditionInterface(metaclass=ABCMeta):

     condition_types = ['physical', 'supervised', 'unsupervised']

-    def __init__(self):
+    def __init__(self, *args, **wargs):
         self._condition_type = None
+        self._problem = None
+
+    @property
+    def problem(self):
+        return self._problem
+
+    @problem.setter
+    def problem(self, value):
+        self._problem = value

     @property
     def condition_type(self):
         return self._condition_type

-    @condition_type.setattr
+    @condition_type.setter
     def condition_type(self, values):
         if not isinstance(values, (list, tuple)):
             values = [values]
@@ -24,21 +24,7 @@ class DataConditionInterface(ConditionInterface):
         self.conditionalvariable = conditionalvariable
         self.condition_type = 'unsupervised'

-    @property
-    def data(self):
-        return self._data
-
-    @data.setter
-    def data(self, value):
-        check_consistency(value, (LabelTensor, Graph, torch.Tensor))
-        self._data = value
-
-    @property
-    def conditionalvariable(self):
-        return self._conditionalvariable
-
-    @data.setter
-    def conditionalvariable(self, value):
-        if value is not None:
-            check_consistency(value, (LabelTensor, Graph, torch.Tensor))
-            self._data = value
+    def __setattr__(self, key, value):
+        if (key == 'data') or (key == 'conditionalvariable'):
+            check_consistency(value, (LabelTensor, Graph, torch.Tensor))
+            DataConditionInterface.__dict__[key].__set__(self, value)
@@ -1,8 +1,6 @@
 import torch

 from .condition_interface import ConditionInterface
-from ..label_tensor import LabelTensor
-from ..graph import Graph
 from ..utils import check_consistency
 from ..domain import DomainInterface
 from ..equation.equation_interface import EquationInterface
@@ -24,20 +22,10 @@ class DomainEquationCondition(ConditionInterface):
         self.equation = equation
         self.condition_type = 'physics'

-    @property
-    def domain(self):
-        return self._domain
-
-    @domain.setter
-    def domain(self, value):
-        check_consistency(value, (DomainInterface))
-        self._domain = value
-
-    @property
-    def equation(self):
-        return self._equation
-
-    @equation.setter
-    def equation(self, value):
-        check_consistency(value, (EquationInterface))
-        self._equation = value
+    def __setattr__(self, key, value):
+        if key == 'domain':
+            check_consistency(value, (DomainInterface))
+            DomainEquationCondition.__dict__[key].__set__(self, value)
+        elif key == 'equation':
+            check_consistency(value, (EquationInterface))
+            DomainEquationCondition.__dict__[key].__set__(self, value)
@@ -23,20 +23,10 @@ class InputPointsEquationCondition(ConditionInterface):
         self.equation = equation
         self.condition_type = 'physics'

-    @property
-    def input_points(self):
-        return self._input_points
-
-    @input_points.setter
-    def input_points(self, value):
-        check_consistency(value, (LabelTensor))  # for now only labeltensors, we need labels for the operators!
-        self._input_points = value
-
-    @property
-    def equation(self):
-        return self._equation
-
-    @equation.setter
-    def equation(self, value):
-        check_consistency(value, (EquationInterface))
-        self._equation = value
+    def __setattr__(self, key, value):
+        if key == 'input_points':
+            check_consistency(value, (LabelTensor))  # for now only labeltensors, we need labels for the operators!
+            InputPointsEquationCondition.__dict__[key].__set__(self, value)
+        elif key == 'equation':
+            check_consistency(value, (EquationInterface))
+            InputPointsEquationCondition.__dict__[key].__set__(self, value)
@@ -23,20 +23,7 @@ class InputOutputPointsCondition(ConditionInterface):
         self.output_points = output_points
         self.condition_type = ['supervised', 'physics']

-    @property
-    def input_points(self):
-        return self._input_points
-
-    @input_points.setter
-    def input_points(self, value):
-        check_consistency(value, (LabelTensor, Graph, torch.Tensor))
-        self._input_points = value
-
-    @property
-    def output_points(self):
-        return self._output_points
-
-    @output_points.setter
-    def output_points(self, value):
-        check_consistency(value, (LabelTensor, Graph, torch.Tensor))
-        self._output_points = value
+    def __setattr__(self, key, value):
+        if (key == 'input_points') or (key == 'output_points'):
+            check_consistency(value, (LabelTensor, Graph, torch.Tensor))
+            InputOutputPointsCondition.__dict__[key].__set__(self, value)
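The recurring pattern in these condition classes is to validate inside `__setattr__` and then write the value through the member descriptor that `__slots__` creates on the class, which avoids recursing back into `__setattr__`. A standalone sketch of the idiom (not part of the commit, unrelated to the PINA classes):

```python
class Point:
    __slots__ = ["x"]

    def __setattr__(self, key, value):
        if key == "x":
            if not isinstance(value, (int, float)):
                raise ValueError("x must be numeric")
            # write via the slot descriptor instead of setattr,
            # so __setattr__ is not triggered again
            Point.__dict__[key].__set__(self, value)


p = Point()
p.x = 1.5      # validated, then stored in the slot
# p.x = "bad"  # would raise ValueError
```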
@@ -21,7 +21,6 @@ class CartesianDomain(DomainInterface):
         """
         self.fixed_ = {}
         self.range_ = {}
-        self.sample_modes = ["random", "grid", "lh", "chebyshev", "latin"]

         for k, v in cartesian_dict.items():
             if isinstance(v, (int, float)):
@@ -31,6 +30,10 @@ class CartesianDomain(DomainInterface):
             else:
                 raise TypeError

+    @property
+    def sample_modes(self):
+        return ["random", "grid", "lh", "chebyshev", "latin"]
+
     @property
     def variables(self):
         """Spatial variables.
@@ -9,7 +9,7 @@ class DomainInterface(metaclass=ABCMeta):
     Any geometry entity should inherit from this class.
     """

-    __available_sampling_modes = ["random", "grid", "lh", "chebyshev", "latin"]
+    available_sampling_modes = ["random", "grid", "lh", "chebyshev", "latin"]

     @property
     @abstractmethod
@@ -19,6 +19,14 @@ class DomainInterface(metaclass=ABCMeta):
         """
         pass

+    @property
+    @abstractmethod
+    def variables(self):
+        """
+        Abstract method returing Domain variables.
+        """
+        pass
+
     @sample_modes.setter
     def sample_modes(self, values):
         """
@@ -27,10 +35,10 @@ class DomainInterface(metaclass=ABCMeta):
         if not isinstance(values, (list, tuple)):
             values = [values]
         for value in values:
-            if value not in DomainInterface.__available_sampling_modes:
+            if value not in DomainInterface.available_sampling_modes:
                 raise TypeError(f"mode {value} not valid. Expected at least "
                                 "one in "
-                                f"{DomainInterface.__available_sampling_modes}."
+                                f"{DomainInterface.available_sampling_modes}."
                                 )

     @abstractmethod
@@ -39,7 +39,6 @@ class EllipsoidDomain(DomainInterface):
         self.range_ = {}
         self._centers = None
         self._axis = None
-        self.sample_modes = "random"

         # checking consistency
         check_consistency(sample_surface, bool)
@@ -72,6 +71,10 @@ class EllipsoidDomain(DomainInterface):
         self._centers = dict(zip(self.range_.keys(), centers.tolist()))
         self._axis = dict(zip(self.range_.keys(), ellipsoid_axis.tolist()))

+    @property
+    def sample_modes(self):
+        return ["random"]
+
     @property
     def variables(self):
         """Spatial variables.
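With these changes each domain now advertises its admissible sampling modes through a read-only property instead of mutable instance state. A quick check, assuming the domain classes are importable as shown (the import path is an assumption and may differ in this branch):

```python
from pina.domain import CartesianDomain, EllipsoidDomain  # hypothetical import path

cartesian = CartesianDomain({"x": [0, 1], "y": [0, 1]})
ellipsoid = EllipsoidDomain({"x": [-1, 1], "y": [-1, 1]})

print(cartesian.sample_modes)  # ['random', 'grid', 'lh', 'chebyshev', 'latin']
print(ellipsoid.sample_modes)  # ['random']
```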
@@ -24,8 +24,9 @@ class OperationInterface(DomainInterface, metaclass=ABCMeta):
         # assign geometries
         self._geometries = geometries

-        # sampling mode, for now random is the only available
-        self.sample_modes = "random"
+    @property
+    def sample_modes(self):
+        return ["random"]

     @property
     def geometries(self):
@@ -74,9 +74,10 @@ class SimplexDomain(DomainInterface):
         # build cartesian_bound
         self._cartesian_bound = self._build_cartesian(self._vertices_matrix)

-        # sampling mode
-        self.sample_modes = "random"
+    @property
+    def sample_modes(self):
+        return ["random"]

     @property
     def variables(self):
         return sorted(self._vertices_matrix.labels)
@@ -32,9 +32,16 @@ class Union(OperationInterface):

         """
         super().__init__(geometries)

+    @property
+    def sample_modes(self):
         self.sample_modes = list(
-            set([geom.sample_modes for geom in geometries])
+            set([geom.sample_modes for geom in self.geometries])
         )

+    @property
+    def variables(self):
+        return list(set([geom.variables for geom in self.geometries]))
+
     def is_inside(self, point, check_border=False):
         """
@@ -1,7 +1,7 @@
 """Module Averaging Neural Operator."""

 import torch
-from torch import nn, concatenate
+from torch import nn, cat
 from .layers import AVNOBlock
 from .base_no import KernelNeuralOperator
 from pina.utils import check_consistency
@@ -110,9 +110,9 @@ class AveragingNeuralOperator(KernelNeuralOperator):
         """
         points_tmp = x.extract(self.coordinates_indices)
         new_batch = x.extract(self.field_indices)
-        new_batch = concatenate((new_batch, points_tmp), dim=-1)
+        new_batch = cat((new_batch, points_tmp), dim=-1)
         new_batch = self._lifting_operator(new_batch)
         new_batch = self._integral_kernels(new_batch)
-        new_batch = concatenate((new_batch, points_tmp), dim=-1)
+        new_batch = cat((new_batch, points_tmp), dim=-1)
         new_batch = self._projection_operator(new_batch)
         return new_batch
@@ -1,7 +1,7 @@
 """Module LowRank Neural Operator."""

 import torch
-from torch import nn, concatenate
+from torch import nn, cat

 from pina.utils import check_consistency

@@ -145,4 +145,4 @@ class LowRankNeuralOperator(KernelNeuralOperator):
         for module in self._integral_kernels:
             x = module(x, coords)
         # projecting
-        return self._projection_operator(concatenate((x, coords), dim=-1))
+        return self._projection_operator(cat((x, coords), dim=-1))
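The `concatenate` to `cat` swap is purely cosmetic: `torch.cat` is the canonical name and `torch.concatenate` is its NumPy-style alias in recent PyTorch versions, so the operators' behaviour is unchanged. A minimal check:

```python
import torch

a = torch.randn(2, 3)
b = torch.randn(2, 5)

out = torch.cat((a, b), dim=-1)  # identical result to torch.concatenate
print(out.shape)                 # torch.Size([2, 8])
```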
@@ -1,12 +1,11 @@
 """ Module for AbstractProblem class """

 from abc import ABCMeta, abstractmethod
-from ..utils import merge_tensors, check_consistency
+from ..utils import check_consistency
+from ..domain import DomainInterface
+from ..condition.domain_equation_condition import DomainEquationCondition
+from ..collector import Collector
 from copy import deepcopy
-import torch
-
-from .. import LabelTensor


 class AbstractProblem(metaclass=ABCMeta):
     """
@@ -20,27 +19,25 @@ class AbstractProblem(metaclass=ABCMeta):

     def __init__(self):

-        self._discretized_domains = {}
-        for name, domain in self.domains.items():
-            if isinstance(domain, (torch.Tensor, LabelTensor)):
-                self._discretized_domains[name] = domain
+        # create collector to manage problem data
+        self.collector = Collector(self)

+        # create hook conditions <-> problems
         for condition_name in self.conditions:
-            self.conditions[condition_name].set_problem(self)
+            self.conditions[condition_name].problem = self

-        # # variable storing all points
-        self.input_pts = {}
-
-        # # varible to check if sampling is done. If no location
-        # # element is presented in Condition this variable is set to true
-        # self._have_sampled_points = {}
-        for condition_name in self.conditions:
-            self._discretized_domains[condition_name] = False
-
-        # # put in self.input_pts all the points that we don't need to sample
-        self._span_condition_points()
+        # store in collector all the available fixed points
+        # note that some points could not be stored at this stage (e.g. when
+        # sampling locations). To check that all data points are ready for
+        # training all type self.collector.full, which returns true if all
+        # points are ready.
+        self.collector.store_fixed_data()
+
+    @property
+    def input_pts(self):
+        return self.collector.data_collections

     def __deepcopy__(self, memo):
         """
         Implements deepcopy for the
@@ -85,19 +82,6 @@ class AbstractProblem(metaclass=ABCMeta):
     def input_variables(self, variables):
         raise RuntimeError

-    @property
-    @abstractmethod
-    def domains(self):
-        """
-        The domain(s) where the conditions of the AbstractProblem are valid.
-        If more than one domain type is passed, a list of Location is
-        retured.
-
-        :return: the domain(s) of ``self``
-        :rtype: list[Location]
-        """
-        pass
-
     @property
     @abstractmethod
     def output_variables(self):
@@ -114,34 +98,8 @@ class AbstractProblem(metaclass=ABCMeta):
         """
         return self._conditions

-    def _span_condition_points(self):
-        """
-        Simple function to get the condition points
-        """
-        for condition_name in self.conditions:
-            condition = self.conditions[condition_name]
-            if hasattr(condition, "input_points"):
-                samples = condition.input_points
-                self.input_pts[condition_name] = samples
-                self._discretized_domains[condition_name] = True
-        if hasattr(self, "unknown_parameter_domain"):
-            # initialize the unknown parameters of the inverse problem given
-            # the domain the user gives
-            self.unknown_parameters = {}
-            for i, var in enumerate(self.unknown_variables):
-                range_var = self.unknown_parameter_domain.range_[var]
-                tensor_var = (
-                    torch.rand(1, requires_grad=True) * range_var[1]
-                    + range_var[0]
-                )
-                self.unknown_parameters[var] = torch.nn.Parameter(
-                    tensor_var
-                )
-
     def discretise_domain(
-        self, n, mode="random", variables="all", domains="all"
+        self, n, mode="random", variables="all", locations="all"
     ):
         """
         Generate a set of points to span the `Location` of all the conditions of
@@ -172,119 +130,38 @@ class AbstractProblem(metaclass=ABCMeta):
         ``CartesianDomain``.
         """

-        # check consistecy n
+        # check consistecy n, mode, variables, locations
         check_consistency(n, int)
-
-        # check consistency mode
         check_consistency(mode, str)
-        if mode not in ["random", "grid", "lh", "chebyshev", "latin"]:
+        check_consistency(variables, str)
+        check_consistency(locations, str)
+
+        # check correct sampling mode
+        if mode not in DomainInterface.available_sampling_modes:
             raise TypeError(f"mode {mode} not valid.")

-        # check consistency variables
+        # check correct variables
         if variables == "all":
             variables = self.input_variables
-        else:
-            check_consistency(variables, str)
-
-        if sorted(variables) != sorted(self.input_variables):
-            TypeError(
-                f"Wrong variables for sampling. Variables ",
-                f"should be in {self.input_variables}.",
-            )
-
-        # # check consistency location # TODO: check if this is needed (from 0.1)
-        # locations_to_sample = [
-        #     condition
-        #     for condition in self.conditions
-        #     if hasattr(self.conditions[condition], "location")
-        # ]
-        # if locations == "all":
-        #     # only locations that can be sampled
-        #     locations = locations_to_sample
-        # else:
-        #     check_consistency(locations, str)
-
-        # if sorted(locations) != sorted(locations_to_sample):
-        if domains == "all":
-            domains = [condition for condition in self.conditions]
-        else:
-            check_consistency(domains, str)
-        print(domains)
-        if sorted(domains) != sorted(self.conditions):
-            TypeError(
-                f"Wrong locations for sampling. Location ",
-                f"should be in {locations_to_sample}.",
-            )
-
-        # sampling
-        for d in domains:
-            condition = self.conditions[d]
-
-            # we try to check if we have already sampled
-            try:
-                already_sampled = [self.input_pts[d]]
-            # if we have not sampled, a key error is thrown
-            except KeyError:
-                already_sampled = []
-
-            # if we have already sampled fully the condition
-            # but we want to sample again we set already_sampled
-            # to an empty list since we need to sample again, and
-            # self._have_sampled_points to False.
-            if self._discretized_domains[d]:
-                already_sampled = []
-                self._discretized_domains[d] = False
-            print(condition.domain)
-            print(d)
-            # build samples
-            samples = [
-                self.domains[d].sample(n=n, mode=mode, variables=variables)
-            ] + already_sampled
-            pts = merge_tensors(samples)
-            self.input_pts[d] = pts
-
-            # the condition is sampled if input_pts contains all labels
-            if sorted(self.input_pts[d].labels) == sorted(
-                self.input_variables
-            ):
-                # self._have_sampled_points[location] = True
-                # self.input_pts[location] = self.input_pts[location].extract(
-                #     sorted(self.input_variables)
-                # )
-                self._have_sampled_points[d] = True
-
-    def add_points(self, new_points):
-        """
-        Adding points to the already sampled points.
-
-        :param dict new_points: a dictionary with key the location to add the points
-            and values the torch.Tensor points.
-        """
-
-        if sorted(new_points.keys()) != sorted(self.conditions):
-            TypeError(
-                f"Wrong locations for new points. Location ",
-                f"should be in {self.conditions}.",
-            )
-
-        for location in new_points.keys():
-            # extract old and new points
-            old_pts = self.input_pts[location]
-            new_pts = new_points[location]
-
-            # if they don't have the same variables error
-            if sorted(old_pts.labels) != sorted(new_pts.labels):
-                TypeError(
-                    f"Not matching variables for old and new points "
-                    f"in condition {location}."
-                )
-            if old_pts.labels != new_pts.labels:
-                new_pts = torch.hstack(
-                    [new_pts.extract([i]) for i in old_pts.labels]
-                )
-                new_pts.labels = old_pts.labels
-
-            # merging
-            merged_pts = torch.vstack([old_pts, new_pts])
-            merged_pts.labels = old_pts.labels
-            self.input_pts[location] = merged_pts
+        for variable in variables:
+            if variable not in self.input_variables:
+                TypeError(
+                    f"Wrong variables for sampling. Variables ",
+                    f"should be in {self.input_variables}.",
+                )
+
+        # check correct location
+        if locations == "all":
+            locations = [name for name in self.conditions.keys()]
+        else:
+            if not isinstance(locations, (list)):
+                locations = [locations]
+            for loc in locations:
+                if not isinstance(self.conditions[loc], DomainEquationCondition):
+                    raise TypeError(
+                        f"Wrong locations passed, locations for sampling "
+                        f"should be in {[loc for loc in locations if not isinstance(self.conditions[loc], DomainEquationCondition)]}.",
+                    )
+
+        # store data
+        self.collector.store_sample_domains(n, mode, variables, locations)
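Sampling therefore goes through `discretise_domain`, which validates the arguments and delegates the actual work to the collector. A usage sketch mirroring the test file further below (the `Poisson` problem and the condition name `'D'` come from that test):

```python
poisson_problem = Poisson()

# sample 10 grid points for the domain-based condition 'D'
poisson_problem.discretise_domain(10, "grid", variables="all", locations=["D"])

# sampled points are exposed through the collector-backed property
print(poisson_problem.input_pts["D"])
print(poisson_problem.collector.full)  # True only once every condition is ready
```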
@@ -45,6 +45,20 @@ class InverseProblem(AbstractProblem):
     >>> 'data': Condition(CartesianDomain({'x': [0, 1]}), Equation(solution_data))
     """

+    def __init__(self):
+        super().__init__()
+        # storing unknown_parameters for optimization
+        self.unknown_parameters = {}
+        for i, var in enumerate(self.unknown_variables):
+            range_var = self.unknown_parameter_domain.range_[var]
+            tensor_var = (
+                torch.rand(1, requires_grad=True) * range_var[1]
+                + range_var[0]
+            )
+            self.unknown_parameters[var] = torch.nn.Parameter(
+                tensor_var
+            )
+
     @abstractmethod
     def unknown_parameter_domain(self):
         """
@@ -27,31 +27,31 @@ class Poisson(SpatialProblem):

     conditions = {
         'gamma1':
-            Condition(location=CartesianDomain({
+            Condition(domain=CartesianDomain({
                 'x': [0, 1],
                 'y': 1
             }),
                       equation=FixedValue(0.0)),
         'gamma2':
-            Condition(location=CartesianDomain({
+            Condition(domain=CartesianDomain({
                 'x': [0, 1],
                 'y': 0
             }),
                       equation=FixedValue(0.0)),
         'gamma3':
-            Condition(location=CartesianDomain({
+            Condition(domain=CartesianDomain({
                 'x': 1,
                 'y': [0, 1]
             }),
                       equation=FixedValue(0.0)),
         'gamma4':
-            Condition(location=CartesianDomain({
+            Condition(domain=CartesianDomain({
                 'x': 0,
                 'y': [0, 1]
             }),
                       equation=FixedValue(0.0)),
         'D':
-            Condition(location=CartesianDomain({
+            Condition(domain=CartesianDomain({
                 'x': [0, 1],
                 'y': [0, 1]
             }),
@@ -67,6 +67,10 @@
     truth_solution = poisson_sol


+# make the problem
+poisson_problem = Poisson()
+print(poisson_problem.input_pts)
+
 def test_discretise_domain():
     n = 10
     poisson_problem = Poisson()
@@ -90,15 +94,14 @@ def test_discretise_domain():
     assert poisson_problem.input_pts['D'].shape[0] == n


-def test_sampling_few_variables():
-    n = 10
-    poisson_problem = Poisson()
-    poisson_problem.discretise_domain(n,
-                                      'grid',
-                                      locations=['D'],
-                                      variables=['x'])
-    assert poisson_problem.input_pts['D'].shape[1] == 1
-    assert poisson_problem._have_sampled_points['D'] is False
+# def test_sampling_few_variables():
+#     n = 10
+#     poisson_problem.discretise_domain(n,
+#                                       'grid',
+#                                       locations=['D'],
+#                                       variables=['x'])
+#     assert poisson_problem.input_pts['D'].shape[1] == 1
+#     assert poisson_problem._have_sampled_points['D'] is False


 def test_variables_correct_order_sampling():