* Adding Collector for handling data sampling/collection before dataset/dataloader
* Modify domain by adding sample_mode, variables as property * Small change concatenate -> cat in lno/avno * Create different factory classes for conditions
This commit is contained in:
committed by
Nicola Demo
parent
aef5a5d590
commit
1bd3f40f54
72
pina/collector.py
Normal file
72
pina/collector.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from .utils import check_consistency, merge_tensors
|
||||
|
||||
class Collector:
|
||||
def __init__(self, problem):
|
||||
self.problem = problem # hook Collector <-> Problem
|
||||
self.data_collections = {name : {} for name in self.problem.conditions} # collection of data
|
||||
self.is_conditions_ready = {
|
||||
name : False for name in self.problem.conditions} # names of the conditions that need to be sampled
|
||||
self.full = False # collector full, all points for all conditions are given and the data are ready to be used in trainig
|
||||
|
||||
@property
|
||||
def full(self):
|
||||
return all(self.is_conditions_ready.values())
|
||||
|
||||
@full.setter
|
||||
def full(self, value):
|
||||
check_consistency(value, bool)
|
||||
self._full = value
|
||||
|
||||
@property
|
||||
def problem(self):
|
||||
return self._problem
|
||||
|
||||
@problem.setter
|
||||
def problem(self, value):
|
||||
self._problem = value
|
||||
|
||||
def store_fixed_data(self):
|
||||
# loop over all conditions
|
||||
for condition_name, condition in self.problem.conditions.items():
|
||||
# if the condition is not ready and domain is not attribute
|
||||
# of condition, we get and store the data
|
||||
if (not self.is_conditions_ready[condition_name]) and (not hasattr(condition, "domain")):
|
||||
# get data
|
||||
keys = condition.__slots__
|
||||
values = [getattr(condition, name) for name in keys]
|
||||
self.data_collections[condition_name] = dict(zip(keys, values))
|
||||
# condition now is ready
|
||||
self.is_conditions_ready[condition_name] = True
|
||||
|
||||
def store_sample_domains(self, n, mode, variables, sample_locations):
|
||||
# loop over all locations
|
||||
for loc in sample_locations:
|
||||
# get condition
|
||||
condition = self.problem.conditions[loc]
|
||||
keys = ["input_points", "equation"]
|
||||
# if the condition is not ready, we get and store the data
|
||||
if (not self.is_conditions_ready[loc]):
|
||||
# if it is the first time we sample
|
||||
if not self.data_collections[loc]:
|
||||
already_sampled = []
|
||||
# if we have sampled the condition but not all variables
|
||||
else:
|
||||
already_sampled = [self.data_collections[loc].input_points]
|
||||
# if the condition is ready but we want to sample again
|
||||
else:
|
||||
self.is_conditions_ready[loc] = False
|
||||
already_sampled = []
|
||||
|
||||
# get the samples
|
||||
samples = [
|
||||
condition.domain.sample(n=n, mode=mode, variables=variables)
|
||||
] + already_sampled
|
||||
pts = merge_tensors(samples)
|
||||
if (
|
||||
sorted(self.data_collections[loc].input_points.labels)
|
||||
==
|
||||
sorted(self.problem.input_variables)
|
||||
):
|
||||
self.is_conditions_ready[loc] = True
|
||||
values = [pts, condition.equation]
|
||||
self.data_collections[loc] = dict(zip(keys, values))
|
||||
@@ -39,11 +39,10 @@ class Condition:
|
||||
|
||||
__slots__ = list(
|
||||
set(
|
||||
InputOutputPointsCondition.__slots__,
|
||||
InputPointsEquationCondition.__slots__,
|
||||
DomainEquationCondition.__slots__,
|
||||
InputOutputPointsCondition.__slots__ +
|
||||
InputPointsEquationCondition.__slots__ +
|
||||
DomainEquationCondition.__slots__ +
|
||||
DataConditionInterface.__slots__
|
||||
|
||||
)
|
||||
)
|
||||
|
||||
@@ -51,8 +50,8 @@ class Condition:
|
||||
|
||||
if len(args) != 0:
|
||||
raise ValueError(
|
||||
f"Condition takes only the following keyword '
|
||||
'arguments: {Condition.__slots__}."
|
||||
"Condition takes only the following keyword "
|
||||
f"arguments: {Condition.__slots__}."
|
||||
)
|
||||
|
||||
sorted_keys = sorted(kwargs.keys())
|
||||
|
||||
@@ -1,18 +1,27 @@
|
||||
|
||||
from abc import ABCMeta
|
||||
|
||||
|
||||
class ConditionInterface(metaclass=ABCMeta):
|
||||
|
||||
condition_types = ['physical', 'supervised', 'unsupervised']
|
||||
def __init__(self):
|
||||
|
||||
def __init__(self, *args, **wargs):
|
||||
self._condition_type = None
|
||||
self._problem = None
|
||||
|
||||
@property
|
||||
def problem(self):
|
||||
return self._problem
|
||||
|
||||
@problem.setter
|
||||
def problem(self, value):
|
||||
self._problem = value
|
||||
|
||||
@property
|
||||
def condition_type(self):
|
||||
return self._condition_type
|
||||
|
||||
@condition_type.setattr
|
||||
@condition_type.setter
|
||||
def condition_type(self, values):
|
||||
if not isinstance(values, (list, tuple)):
|
||||
values = [values]
|
||||
|
||||
@@ -24,21 +24,7 @@ class DataConditionInterface(ConditionInterface):
|
||||
self.conditionalvariable = conditionalvariable
|
||||
self.condition_type = 'unsupervised'
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return self._data
|
||||
|
||||
@data.setter
|
||||
def data(self, value):
|
||||
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
|
||||
self._data = value
|
||||
|
||||
@property
|
||||
def conditionalvariable(self):
|
||||
return self._conditionalvariable
|
||||
|
||||
@data.setter
|
||||
def conditionalvariable(self, value):
|
||||
if value is not None:
|
||||
def __setattr__(self, key, value):
|
||||
if (key == 'data') or (key == 'conditionalvariable'):
|
||||
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
|
||||
self._data = value
|
||||
DataConditionInterface.__dict__[key].__set__(self, value)
|
||||
@@ -1,8 +1,6 @@
|
||||
import torch
|
||||
|
||||
from .condition_interface import ConditionInterface
|
||||
from ..label_tensor import LabelTensor
|
||||
from ..graph import Graph
|
||||
from ..utils import check_consistency
|
||||
from ..domain import DomainInterface
|
||||
from ..equation.equation_interface import EquationInterface
|
||||
@@ -24,20 +22,10 @@ class DomainEquationCondition(ConditionInterface):
|
||||
self.equation = equation
|
||||
self.condition_type = 'physics'
|
||||
|
||||
@property
|
||||
def domain(self):
|
||||
return self._domain
|
||||
|
||||
@domain.setter
|
||||
def domain(self, value):
|
||||
check_consistency(value, (DomainInterface))
|
||||
self._domain = value
|
||||
|
||||
@property
|
||||
def equation(self):
|
||||
return self._equation
|
||||
|
||||
@equation.setter
|
||||
def equation(self, value):
|
||||
check_consistency(value, (EquationInterface))
|
||||
self._equation = value
|
||||
def __setattr__(self, key, value):
|
||||
if key == 'domain':
|
||||
check_consistency(value, (DomainInterface))
|
||||
DomainEquationCondition.__dict__[key].__set__(self, value)
|
||||
elif key == 'equation':
|
||||
check_consistency(value, (EquationInterface))
|
||||
DomainEquationCondition.__dict__[key].__set__(self, value)
|
||||
@@ -23,20 +23,10 @@ class InputPointsEquationCondition(ConditionInterface):
|
||||
self.equation = equation
|
||||
self.condition_type = 'physics'
|
||||
|
||||
@property
|
||||
def input_points(self):
|
||||
return self._input_points
|
||||
|
||||
@input_points.setter
|
||||
def input_points(self, value):
|
||||
check_consistency(value, (LabelTensor)) # for now only labeltensors, we need labels for the operators!
|
||||
self._input_points = value
|
||||
|
||||
@property
|
||||
def equation(self):
|
||||
return self._equation
|
||||
|
||||
@equation.setter
|
||||
def equation(self, value):
|
||||
check_consistency(value, (EquationInterface))
|
||||
self._equation = value
|
||||
def __setattr__(self, key, value):
|
||||
if key == 'input_points':
|
||||
check_consistency(value, (LabelTensor)) # for now only labeltensors, we need labels for the operators!
|
||||
InputPointsEquationCondition.__dict__[key].__set__(self, value)
|
||||
elif key == 'equation':
|
||||
check_consistency(value, (EquationInterface))
|
||||
InputPointsEquationCondition.__dict__[key].__set__(self, value)
|
||||
@@ -23,20 +23,7 @@ class InputOutputPointsCondition(ConditionInterface):
|
||||
self.output_points = output_points
|
||||
self.condition_type = ['supervised', 'physics']
|
||||
|
||||
@property
|
||||
def input_points(self):
|
||||
return self._input_points
|
||||
|
||||
@input_points.setter
|
||||
def input_points(self, value):
|
||||
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
|
||||
self._input_points = value
|
||||
|
||||
@property
|
||||
def output_points(self):
|
||||
return self._output_points
|
||||
|
||||
@output_points.setter
|
||||
def output_points(self, value):
|
||||
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
|
||||
self._output_points = value
|
||||
def __setattr__(self, key, value):
|
||||
if (key == 'input_points') or (key == 'output_points'):
|
||||
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
|
||||
InputOutputPointsCondition.__dict__[key].__set__(self, value)
|
||||
@@ -21,7 +21,6 @@ class CartesianDomain(DomainInterface):
|
||||
"""
|
||||
self.fixed_ = {}
|
||||
self.range_ = {}
|
||||
self.sample_modes = ["random", "grid", "lh", "chebyshev", "latin"]
|
||||
|
||||
for k, v in cartesian_dict.items():
|
||||
if isinstance(v, (int, float)):
|
||||
@@ -31,6 +30,10 @@ class CartesianDomain(DomainInterface):
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
@property
|
||||
def sample_modes(self):
|
||||
return ["random", "grid", "lh", "chebyshev", "latin"]
|
||||
|
||||
@property
|
||||
def variables(self):
|
||||
"""Spatial variables.
|
||||
|
||||
@@ -9,7 +9,7 @@ class DomainInterface(metaclass=ABCMeta):
|
||||
Any geometry entity should inherit from this class.
|
||||
"""
|
||||
|
||||
__available_sampling_modes = ["random", "grid", "lh", "chebyshev", "latin"]
|
||||
available_sampling_modes = ["random", "grid", "lh", "chebyshev", "latin"]
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
@@ -19,6 +19,14 @@ class DomainInterface(metaclass=ABCMeta):
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def variables(self):
|
||||
"""
|
||||
Abstract method returing Domain variables.
|
||||
"""
|
||||
pass
|
||||
|
||||
@sample_modes.setter
|
||||
def sample_modes(self, values):
|
||||
"""
|
||||
@@ -27,10 +35,10 @@ class DomainInterface(metaclass=ABCMeta):
|
||||
if not isinstance(values, (list, tuple)):
|
||||
values = [values]
|
||||
for value in values:
|
||||
if value not in DomainInterface.__available_sampling_modes:
|
||||
if value not in DomainInterface.available_sampling_modes:
|
||||
raise TypeError(f"mode {value} not valid. Expected at least "
|
||||
"one in "
|
||||
f"{DomainInterface.__available_sampling_modes}."
|
||||
f"{DomainInterface.available_sampling_modes}."
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -39,7 +39,6 @@ class EllipsoidDomain(DomainInterface):
|
||||
self.range_ = {}
|
||||
self._centers = None
|
||||
self._axis = None
|
||||
self.sample_modes = "random"
|
||||
|
||||
# checking consistency
|
||||
check_consistency(sample_surface, bool)
|
||||
@@ -72,6 +71,10 @@ class EllipsoidDomain(DomainInterface):
|
||||
self._centers = dict(zip(self.range_.keys(), centers.tolist()))
|
||||
self._axis = dict(zip(self.range_.keys(), ellipsoid_axis.tolist()))
|
||||
|
||||
@property
|
||||
def sample_modes(self):
|
||||
return ["random"]
|
||||
|
||||
@property
|
||||
def variables(self):
|
||||
"""Spatial variables.
|
||||
|
||||
@@ -24,8 +24,9 @@ class OperationInterface(DomainInterface, metaclass=ABCMeta):
|
||||
# assign geometries
|
||||
self._geometries = geometries
|
||||
|
||||
# sampling mode, for now random is the only available
|
||||
self.sample_modes = "random"
|
||||
@property
|
||||
def sample_modes(self):
|
||||
return ["random"]
|
||||
|
||||
@property
|
||||
def geometries(self):
|
||||
|
||||
@@ -74,9 +74,10 @@ class SimplexDomain(DomainInterface):
|
||||
# build cartesian_bound
|
||||
self._cartesian_bound = self._build_cartesian(self._vertices_matrix)
|
||||
|
||||
# sampling mode
|
||||
self.sample_modes = "random"
|
||||
|
||||
@property
|
||||
def sample_modes(self):
|
||||
return ["random"]
|
||||
|
||||
@property
|
||||
def variables(self):
|
||||
return sorted(self._vertices_matrix.labels)
|
||||
|
||||
@@ -32,9 +32,16 @@ class Union(OperationInterface):
|
||||
|
||||
"""
|
||||
super().__init__(geometries)
|
||||
|
||||
@property
|
||||
def sample_modes(self):
|
||||
self.sample_modes = list(
|
||||
set([geom.sample_modes for geom in geometries])
|
||||
set([geom.sample_modes for geom in self.geometries])
|
||||
)
|
||||
|
||||
@property
|
||||
def variables(self):
|
||||
return list(set([geom.variables for geom in self.geometries]))
|
||||
|
||||
def is_inside(self, point, check_border=False):
|
||||
"""
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Module Averaging Neural Operator."""
|
||||
|
||||
import torch
|
||||
from torch import nn, concatenate
|
||||
from torch import nn, cat
|
||||
from .layers import AVNOBlock
|
||||
from .base_no import KernelNeuralOperator
|
||||
from pina.utils import check_consistency
|
||||
@@ -110,9 +110,9 @@ class AveragingNeuralOperator(KernelNeuralOperator):
|
||||
"""
|
||||
points_tmp = x.extract(self.coordinates_indices)
|
||||
new_batch = x.extract(self.field_indices)
|
||||
new_batch = concatenate((new_batch, points_tmp), dim=-1)
|
||||
new_batch = cat((new_batch, points_tmp), dim=-1)
|
||||
new_batch = self._lifting_operator(new_batch)
|
||||
new_batch = self._integral_kernels(new_batch)
|
||||
new_batch = concatenate((new_batch, points_tmp), dim=-1)
|
||||
new_batch = cat((new_batch, points_tmp), dim=-1)
|
||||
new_batch = self._projection_operator(new_batch)
|
||||
return new_batch
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Module LowRank Neural Operator."""
|
||||
|
||||
import torch
|
||||
from torch import nn, concatenate
|
||||
from torch import nn, cat
|
||||
|
||||
from pina.utils import check_consistency
|
||||
|
||||
@@ -145,4 +145,4 @@ class LowRankNeuralOperator(KernelNeuralOperator):
|
||||
for module in self._integral_kernels:
|
||||
x = module(x, coords)
|
||||
# projecting
|
||||
return self._projection_operator(concatenate((x, coords), dim=-1))
|
||||
return self._projection_operator(cat((x, coords), dim=-1))
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
""" Module for AbstractProblem class """
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from ..utils import merge_tensors, check_consistency
|
||||
from ..utils import check_consistency
|
||||
from ..domain import DomainInterface
|
||||
from ..condition.domain_equation_condition import DomainEquationCondition
|
||||
from ..collector import Collector
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
|
||||
from .. import LabelTensor
|
||||
|
||||
|
||||
class AbstractProblem(metaclass=ABCMeta):
|
||||
"""
|
||||
@@ -20,27 +19,25 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
|
||||
def __init__(self):
|
||||
|
||||
self._discretized_domains = {}
|
||||
|
||||
for name, domain in self.domains.items():
|
||||
if isinstance(domain, (torch.Tensor, LabelTensor)):
|
||||
self._discretized_domains[name] = domain
|
||||
# create collector to manage problem data
|
||||
self.collector = Collector(self)
|
||||
|
||||
# create hook conditions <-> problems
|
||||
for condition_name in self.conditions:
|
||||
self.conditions[condition_name].set_problem(self)
|
||||
self.conditions[condition_name].problem = self
|
||||
|
||||
# # variable storing all points
|
||||
self.input_pts = {}
|
||||
# store in collector all the available fixed points
|
||||
# note that some points could not be stored at this stage (e.g. when
|
||||
# sampling locations). To check that all data points are ready for
|
||||
# training all type self.collector.full, which returns true if all
|
||||
# points are ready.
|
||||
self.collector.store_fixed_data()
|
||||
|
||||
# # varible to check if sampling is done. If no location
|
||||
# # element is presented in Condition this variable is set to true
|
||||
# self._have_sampled_points = {}
|
||||
for condition_name in self.conditions:
|
||||
self._discretized_domains[condition_name] = False
|
||||
|
||||
# # put in self.input_pts all the points that we don't need to sample
|
||||
self._span_condition_points()
|
||||
|
||||
@property
|
||||
def input_pts(self):
|
||||
return self.collector.data_collections
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
"""
|
||||
Implements deepcopy for the
|
||||
@@ -85,19 +82,6 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
def input_variables(self, variables):
|
||||
raise RuntimeError
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def domains(self):
|
||||
"""
|
||||
The domain(s) where the conditions of the AbstractProblem are valid.
|
||||
If more than one domain type is passed, a list of Location is
|
||||
retured.
|
||||
|
||||
:return: the domain(s) of ``self``
|
||||
:rtype: list[Location]
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def output_variables(self):
|
||||
@@ -114,34 +98,8 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
"""
|
||||
return self._conditions
|
||||
|
||||
|
||||
|
||||
def _span_condition_points(self):
|
||||
"""
|
||||
Simple function to get the condition points
|
||||
"""
|
||||
for condition_name in self.conditions:
|
||||
condition = self.conditions[condition_name]
|
||||
if hasattr(condition, "input_points"):
|
||||
samples = condition.input_points
|
||||
self.input_pts[condition_name] = samples
|
||||
self._discretized_domains[condition_name] = True
|
||||
if hasattr(self, "unknown_parameter_domain"):
|
||||
# initialize the unknown parameters of the inverse problem given
|
||||
# the domain the user gives
|
||||
self.unknown_parameters = {}
|
||||
for i, var in enumerate(self.unknown_variables):
|
||||
range_var = self.unknown_parameter_domain.range_[var]
|
||||
tensor_var = (
|
||||
torch.rand(1, requires_grad=True) * range_var[1]
|
||||
+ range_var[0]
|
||||
)
|
||||
self.unknown_parameters[var] = torch.nn.Parameter(
|
||||
tensor_var
|
||||
)
|
||||
|
||||
def discretise_domain(
|
||||
self, n, mode="random", variables="all", domains="all"
|
||||
self, n, mode="random", variables="all", locations="all"
|
||||
):
|
||||
"""
|
||||
Generate a set of points to span the `Location` of all the conditions of
|
||||
@@ -172,119 +130,38 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
``CartesianDomain``.
|
||||
"""
|
||||
|
||||
# check consistecy n
|
||||
# check consistecy n, mode, variables, locations
|
||||
check_consistency(n, int)
|
||||
|
||||
# check consistency mode
|
||||
check_consistency(mode, str)
|
||||
if mode not in ["random", "grid", "lh", "chebyshev", "latin"]:
|
||||
check_consistency(variables, str)
|
||||
check_consistency(locations, str)
|
||||
|
||||
# check correct sampling mode
|
||||
if mode not in DomainInterface.available_sampling_modes:
|
||||
raise TypeError(f"mode {mode} not valid.")
|
||||
|
||||
# check consistency variables
|
||||
# check correct variables
|
||||
if variables == "all":
|
||||
variables = self.input_variables
|
||||
else:
|
||||
check_consistency(variables, str)
|
||||
|
||||
if sorted(variables) != sorted(self.input_variables):
|
||||
TypeError(
|
||||
f"Wrong variables for sampling. Variables ",
|
||||
f"should be in {self.input_variables}.",
|
||||
)
|
||||
|
||||
# # check consistency location # TODO: check if this is needed (from 0.1)
|
||||
# locations_to_sample = [
|
||||
# condition
|
||||
# for condition in self.conditions
|
||||
# if hasattr(self.conditions[condition], "location")
|
||||
# ]
|
||||
# if locations == "all":
|
||||
# # only locations that can be sampled
|
||||
# locations = locations_to_sample
|
||||
# else:
|
||||
# check_consistency(locations, str)
|
||||
|
||||
# if sorted(locations) != sorted(locations_to_sample):
|
||||
if domains == "all":
|
||||
domains = [condition for condition in self.conditions]
|
||||
else:
|
||||
check_consistency(domains, str)
|
||||
print(domains)
|
||||
if sorted(domains) != sorted(self.conditions):
|
||||
TypeError(
|
||||
f"Wrong locations for sampling. Location ",
|
||||
f"should be in {locations_to_sample}.",
|
||||
)
|
||||
|
||||
# sampling
|
||||
for d in domains:
|
||||
condition = self.conditions[d]
|
||||
|
||||
# we try to check if we have already sampled
|
||||
try:
|
||||
already_sampled = [self.input_pts[d]]
|
||||
# if we have not sampled, a key error is thrown
|
||||
except KeyError:
|
||||
already_sampled = []
|
||||
|
||||
# if we have already sampled fully the condition
|
||||
# but we want to sample again we set already_sampled
|
||||
# to an empty list since we need to sample again, and
|
||||
# self._have_sampled_points to False.
|
||||
if self._discretized_domains[d]:
|
||||
already_sampled = []
|
||||
self._discretized_domains[d] = False
|
||||
print(condition.domain)
|
||||
print(d)
|
||||
# build samples
|
||||
samples = [
|
||||
self.domains[d].sample(n=n, mode=mode, variables=variables)
|
||||
] + already_sampled
|
||||
pts = merge_tensors(samples)
|
||||
self.input_pts[d] = pts
|
||||
|
||||
# the condition is sampled if input_pts contains all labels
|
||||
if sorted(self.input_pts[d].labels) == sorted(
|
||||
self.input_variables
|
||||
):
|
||||
# self._have_sampled_points[location] = True
|
||||
# self.input_pts[location] = self.input_pts[location].extract(
|
||||
# sorted(self.input_variables)
|
||||
# )
|
||||
self._have_sampled_points[d] = True
|
||||
|
||||
def add_points(self, new_points):
|
||||
"""
|
||||
Adding points to the already sampled points.
|
||||
|
||||
:param dict new_points: a dictionary with key the location to add the points
|
||||
and values the torch.Tensor points.
|
||||
"""
|
||||
|
||||
if sorted(new_points.keys()) != sorted(self.conditions):
|
||||
TypeError(
|
||||
f"Wrong locations for new points. Location ",
|
||||
f"should be in {self.conditions}.",
|
||||
)
|
||||
|
||||
for location in new_points.keys():
|
||||
# extract old and new points
|
||||
old_pts = self.input_pts[location]
|
||||
new_pts = new_points[location]
|
||||
|
||||
# if they don't have the same variables error
|
||||
if sorted(old_pts.labels) != sorted(new_pts.labels):
|
||||
for variable in variables:
|
||||
if variable not in self.input_variables:
|
||||
TypeError(
|
||||
f"Not matching variables for old and new points "
|
||||
f"in condition {location}."
|
||||
f"Wrong variables for sampling. Variables ",
|
||||
f"should be in {self.input_variables}.",
|
||||
)
|
||||
if old_pts.labels != new_pts.labels:
|
||||
new_pts = torch.hstack(
|
||||
[new_pts.extract([i]) for i in old_pts.labels]
|
||||
)
|
||||
new_pts.labels = old_pts.labels
|
||||
|
||||
# merging
|
||||
merged_pts = torch.vstack([old_pts, new_pts])
|
||||
merged_pts.labels = old_pts.labels
|
||||
self.input_pts[location] = merged_pts
|
||||
# check correct location
|
||||
if locations == "all":
|
||||
locations = [name for name in self.conditions.keys()]
|
||||
else:
|
||||
if not isinstance(locations, (list)):
|
||||
locations = [locations]
|
||||
for loc in locations:
|
||||
if not isinstance(self.conditions[loc], DomainEquationCondition):
|
||||
raise TypeError(
|
||||
f"Wrong locations passed, locations for sampling "
|
||||
f"should be in {[loc for loc in locations if not isinstance(self.conditions[loc], DomainEquationCondition)]}.",
|
||||
)
|
||||
|
||||
# store data
|
||||
self.collector.store_sample_domains(n, mode, variables, locations)
|
||||
|
||||
@@ -45,6 +45,20 @@ class InverseProblem(AbstractProblem):
|
||||
>>> 'data': Condition(CartesianDomain({'x': [0, 1]}), Equation(solution_data))
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# storing unknown_parameters for optimization
|
||||
self.unknown_parameters = {}
|
||||
for i, var in enumerate(self.unknown_variables):
|
||||
range_var = self.unknown_parameter_domain.range_[var]
|
||||
tensor_var = (
|
||||
torch.rand(1, requires_grad=True) * range_var[1]
|
||||
+ range_var[0]
|
||||
)
|
||||
self.unknown_parameters[var] = torch.nn.Parameter(
|
||||
tensor_var
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def unknown_parameter_domain(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user