* Add Collector for handling data sampling/collection before dataset/dataloader

* Modify domain classes by exposing sample_modes and variables as properties
* Small change: concatenate -> cat in lno/avno
* Create different factory classes for conditions
This commit is contained in:
Dario Coscia
2024-10-04 13:57:18 +02:00
committed by Nicola Demo
parent aef5a5d590
commit 1bd3f40f54
18 changed files with 225 additions and 277 deletions

72
pina/collector.py Normal file
View File

@@ -0,0 +1,72 @@
from .utils import check_consistency, merge_tensors


class Collector:
    def __init__(self, problem):
        # hook Collector <-> Problem
        self.problem = problem
        # collection of data for each condition
        self.data_collections = {name: {} for name in self.problem.conditions}
        # readiness flag for each condition (True once its data are stored)
        self.is_conditions_ready = {
            name: False for name in self.problem.conditions}
        # collector full: all points for all conditions are given and the
        # data are ready to be used in training
        self.full = False

    @property
    def full(self):
        return all(self.is_conditions_ready.values())

    @full.setter
    def full(self, value):
        check_consistency(value, bool)
        self._full = value

    @property
    def problem(self):
        return self._problem

    @problem.setter
    def problem(self, value):
        self._problem = value

    def store_fixed_data(self):
        # loop over all conditions
        for condition_name, condition in self.problem.conditions.items():
            # if the condition is not ready and the domain is not an
            # attribute of the condition, we get and store the data
            if (not self.is_conditions_ready[condition_name]) and (
                    not hasattr(condition, "domain")):
                # get data
                keys = condition.__slots__
                values = [getattr(condition, name) for name in keys]
                self.data_collections[condition_name] = dict(zip(keys, values))
                # condition now is ready
                self.is_conditions_ready[condition_name] = True

    def store_sample_domains(self, n, mode, variables, sample_locations):
        # loop over all locations
        for loc in sample_locations:
            # get condition
            condition = self.problem.conditions[loc]
            keys = ["input_points", "equation"]
            # if the condition is not ready, we get and store the data
            if not self.is_conditions_ready[loc]:
                # if it is the first time we sample
                if not self.data_collections[loc]:
                    already_sampled = []
                # if we have sampled the condition but not all variables
                else:
                    already_sampled = [
                        self.data_collections[loc]["input_points"]]
            # if the condition is ready but we want to sample again
            else:
                self.is_conditions_ready[loc] = False
                already_sampled = []
            # get the samples and merge them with the ones already stored
            samples = [
                condition.domain.sample(n=n, mode=mode, variables=variables)
            ] + already_sampled
            pts = merge_tensors(samples)
            # the condition is ready once all input variables are sampled
            if sorted(pts.labels) == sorted(self.problem.input_variables):
                self.is_conditions_ready[loc] = True
            values = [pts, condition.equation]
            self.data_collections[loc] = dict(zip(keys, values))
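Note: a minimal usage sketch of the new Collector (the problem instance and condition names below are illustrative, not part of this commit):

```python
from pina.collector import Collector

# `problem` is assumed to be an AbstractProblem subclass instance whose
# conditions include a fixed-data condition and a domain-based condition
# named "interior".
collector = Collector(problem)

# store the conditions that already carry their data (no sampling needed)
collector.store_fixed_data()

# sample the domain-based conditions and store the resulting points
collector.store_sample_domains(
    n=100, mode="random",
    variables=problem.input_variables,
    sample_locations=["interior"],
)

# True once every condition has all the points it needs for training
print(collector.full)
```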

View File

@@ -39,11 +39,10 @@ class Condition:
     __slots__ = list(
         set(
-            InputOutputPointsCondition.__slots__,
-            InputPointsEquationCondition.__slots__,
-            DomainEquationCondition.__slots__,
+            InputOutputPointsCondition.__slots__ +
+            InputPointsEquationCondition.__slots__ +
+            DomainEquationCondition.__slots__ +
             DataConditionInterface.__slots__
         )
     )
@@ -51,8 +50,8 @@ class Condition:
         if len(args) != 0:
             raise ValueError(
-                f"Condition takes only the following keyword '
-                'arguments: {Condition.__slots__}."
+                "Condition takes only the following keyword "
+                f"arguments: {Condition.__slots__}."
             )
         sorted_keys = sorted(kwargs.keys())

View File

@@ -1,18 +1,27 @@
 from abc import ABCMeta


 class ConditionInterface(metaclass=ABCMeta):

     condition_types = ['physical', 'supervised', 'unsupervised']

-    def __init__(self):
+    def __init__(self, *args, **wargs):
         self._condition_type = None
+        self._problem = None
+
+    @property
+    def problem(self):
+        return self._problem
+
+    @problem.setter
+    def problem(self, value):
+        self._problem = value

     @property
     def condition_type(self):
         return self._condition_type

-    @condition_type.setattr
+    @condition_type.setter
     def condition_type(self, values):
         if not isinstance(values, (list, tuple)):
             values = [values]

View File

@@ -24,21 +24,7 @@ class DataConditionInterface(ConditionInterface):
         self.conditionalvariable = conditionalvariable
         self.condition_type = 'unsupervised'

-    @property
-    def data(self):
-        return self._data
-
-    @data.setter
-    def data(self, value):
-        check_consistency(value, (LabelTensor, Graph, torch.Tensor))
-        self._data = value
-
-    @property
-    def conditionalvariable(self):
-        return self._conditionalvariable
-
-    @data.setter
-    def conditionalvariable(self, value):
-        if value is not None:
+    def __setattr__(self, key, value):
+        if (key == 'data') or (key == 'conditionalvariable'):
+            check_consistency(value, (LabelTensor, Graph, torch.Tensor))
+            DataConditionInterface.__dict__[key].__set__(self, value)
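Note: the `__setattr__`-based conditions in this commit all rely on the same descriptor trick: `SomeCondition.__dict__[key].__set__(self, value)` writes through the slot descriptor created by `__slots__`, so every assignment to a listed attribute goes through one validation point. A standalone sketch of the pattern (class and attribute names are illustrative, not part of the diff):

```python
class Point:
    __slots__ = ["x", "y"]

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __setattr__(self, key, value):
        if key in ("x", "y"):
            # single validation point for every slotted attribute
            if not isinstance(value, (int, float)):
                raise TypeError(f"{key} must be numeric, got {type(value)}")
            # write through the slot descriptor stored in the class __dict__
            Point.__dict__[key].__set__(self, value)
        else:
            super().__setattr__(key, value)


p = Point(1.0, 2.0)   # passes validation
p.x = "one"           # raises TypeError
```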

View File

@@ -1,8 +1,6 @@
 import torch
 from .condition_interface import ConditionInterface
-from ..label_tensor import LabelTensor
-from ..graph import Graph
 from ..utils import check_consistency
 from ..domain import DomainInterface
 from ..equation.equation_interface import EquationInterface
@@ -24,20 +22,10 @@ class DomainEquationCondition(ConditionInterface):
         self.equation = equation
         self.condition_type = 'physics'

-    @property
-    def domain(self):
-        return self._domain
-
-    @domain.setter
-    def domain(self, value):
-        check_consistency(value, (DomainInterface))
-        self._domain = value
-
-    @property
-    def equation(self):
-        return self._equation
-
-    @equation.setter
-    def equation(self, value):
-        check_consistency(value, (EquationInterface))
-        self._equation = value
+    def __setattr__(self, key, value):
+        if key == 'domain':
+            check_consistency(value, (DomainInterface))
+            DomainEquationCondition.__dict__[key].__set__(self, value)
+        elif key == 'equation':
+            check_consistency(value, (EquationInterface))
+            DomainEquationCondition.__dict__[key].__set__(self, value)

View File

@@ -23,20 +23,10 @@ class InputPointsEquationCondition(ConditionInterface):
         self.equation = equation
         self.condition_type = 'physics'

-    @property
-    def input_points(self):
-        return self._input_points
-
-    @input_points.setter
-    def input_points(self, value):
-        check_consistency(value, (LabelTensor))  # for now only labeltensors, we need labels for the operators!
-        self._input_points = value
-
-    @property
-    def equation(self):
-        return self._equation
-
-    @equation.setter
-    def equation(self, value):
-        check_consistency(value, (EquationInterface))
-        self._equation = value
+    def __setattr__(self, key, value):
+        if key == 'input_points':
+            check_consistency(value, (LabelTensor))  # for now only labeltensors, we need labels for the operators!
+            InputPointsEquationCondition.__dict__[key].__set__(self, value)
+        elif key == 'equation':
+            check_consistency(value, (EquationInterface))
+            InputPointsEquationCondition.__dict__[key].__set__(self, value)

View File

@@ -23,20 +23,7 @@ class InputOutputPointsCondition(ConditionInterface):
         self.output_points = output_points
         self.condition_type = ['supervised', 'physics']

-    @property
-    def input_points(self):
-        return self._input_points
-
-    @input_points.setter
-    def input_points(self, value):
-        check_consistency(value, (LabelTensor, Graph, torch.Tensor))
-        self._input_points = value
-
-    @property
-    def output_points(self):
-        return self._output_points
-
-    @output_points.setter
-    def output_points(self, value):
-        check_consistency(value, (LabelTensor, Graph, torch.Tensor))
-        self._output_points = value
+    def __setattr__(self, key, value):
+        if (key == 'input_points') or (key == 'output_points'):
+            check_consistency(value, (LabelTensor, Graph, torch.Tensor))
+            InputOutputPointsCondition.__dict__[key].__set__(self, value)

View File

@@ -21,7 +21,6 @@ class CartesianDomain(DomainInterface):
"""
self.fixed_ = {}
self.range_ = {}
self.sample_modes = ["random", "grid", "lh", "chebyshev", "latin"]
for k, v in cartesian_dict.items():
if isinstance(v, (int, float)):
@@ -31,6 +30,10 @@ class CartesianDomain(DomainInterface):
             else:
                 raise TypeError

+    @property
+    def sample_modes(self):
+        return ["random", "grid", "lh", "chebyshev", "latin"]
+
     @property
     def variables(self):
         """Spatial variables.

View File

@@ -9,7 +9,7 @@ class DomainInterface(metaclass=ABCMeta):
     Any geometry entity should inherit from this class.
     """

-    __available_sampling_modes = ["random", "grid", "lh", "chebyshev", "latin"]
+    available_sampling_modes = ["random", "grid", "lh", "chebyshev", "latin"]
@property
@abstractmethod
@@ -19,6 +19,14 @@ class DomainInterface(metaclass=ABCMeta):
"""
pass
@property
@abstractmethod
def variables(self):
"""
Abstract method returing Domain variables.
"""
pass
@sample_modes.setter
def sample_modes(self, values):
"""
@@ -27,10 +35,10 @@ class DomainInterface(metaclass=ABCMeta):
         if not isinstance(values, (list, tuple)):
             values = [values]
         for value in values:
-            if value not in DomainInterface.__available_sampling_modes:
+            if value not in DomainInterface.available_sampling_modes:
                 raise TypeError(f"mode {value} not valid. Expected at least "
                                 "one in "
-                                f"{DomainInterface.__available_sampling_modes}."
+                                f"{DomainInterface.available_sampling_modes}."
                                 )

     @abstractmethod
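Note: with `sample_modes` and `variables` now exposed as properties, a concrete domain only has to override the two getters and provide its sampling logic; a hypothetical minimal subclass (illustrative only, assuming the abstract interface shown above):

```python
import torch

from pina.domain import DomainInterface


class UnitSegment(DomainInterface):
    """Hypothetical 1D domain on [0, 1]."""

    @property
    def sample_modes(self):
        # read-only: sampling strategies supported by this domain
        return ["random", "grid"]

    @property
    def variables(self):
        # read-only: variables spanned by this domain
        return ["x"]

    def sample(self, n, mode="random", variables="all"):
        if mode not in self.sample_modes:
            raise TypeError(f"mode {mode} not valid.")
        pts = torch.linspace(0, 1, n) if mode == "grid" else torch.rand(n)
        return pts.reshape(-1, 1)  # real PINA domains return a LabelTensor

    def is_inside(self, point, check_border=False):
        return bool((point >= 0).all() and (point <= 1).all())
```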

View File

@@ -39,7 +39,6 @@ class EllipsoidDomain(DomainInterface):
         self.range_ = {}
         self._centers = None
         self._axis = None
-        self.sample_modes = "random"

         # checking consistency
         check_consistency(sample_surface, bool)
@@ -72,6 +71,10 @@ class EllipsoidDomain(DomainInterface):
         self._centers = dict(zip(self.range_.keys(), centers.tolist()))
         self._axis = dict(zip(self.range_.keys(), ellipsoid_axis.tolist()))

+    @property
+    def sample_modes(self):
+        return ["random"]
+
     @property
     def variables(self):
         """Spatial variables.

View File

@@ -24,8 +24,9 @@ class OperationInterface(DomainInterface, metaclass=ABCMeta):
         # assign geometries
         self._geometries = geometries
         # sampling mode, for now random is the only available
-        self.sample_modes = "random"

+    @property
+    def sample_modes(self):
+        return ["random"]
+
     @property
     def geometries(self):

View File

@@ -74,9 +74,10 @@ class SimplexDomain(DomainInterface):
         # build cartesian_bound
         self._cartesian_bound = self._build_cartesian(self._vertices_matrix)
         # sampling mode
-        self.sample_modes = "random"

+    @property
+    def sample_modes(self):
+        return ["random"]
+
     @property
     def variables(self):
         return sorted(self._vertices_matrix.labels)

View File

@@ -32,9 +32,16 @@ class Union(OperationInterface):
"""
         super().__init__(geometries)

+    @property
+    def sample_modes(self):
         self.sample_modes = list(
-            set([geom.sample_modes for geom in geometries])
+            set([geom.sample_modes for geom in self.geometries])
         )

+    @property
+    def variables(self):
+        return list(set([geom.variables for geom in self.geometries]))
+
     def is_inside(self, point, check_border=False):
         """

View File

@@ -1,7 +1,7 @@
"""Module Averaging Neural Operator."""
import torch
from torch import nn, concatenate
from torch import nn, cat
from .layers import AVNOBlock
from .base_no import KernelNeuralOperator
from pina.utils import check_consistency
@@ -110,9 +110,9 @@ class AveragingNeuralOperator(KernelNeuralOperator):
"""
         points_tmp = x.extract(self.coordinates_indices)
         new_batch = x.extract(self.field_indices)
-        new_batch = concatenate((new_batch, points_tmp), dim=-1)
+        new_batch = cat((new_batch, points_tmp), dim=-1)
         new_batch = self._lifting_operator(new_batch)
         new_batch = self._integral_kernels(new_batch)
-        new_batch = concatenate((new_batch, points_tmp), dim=-1)
+        new_batch = cat((new_batch, points_tmp), dim=-1)
         new_batch = self._projection_operator(new_batch)
         return new_batch

View File

@@ -1,7 +1,7 @@
"""Module LowRank Neural Operator."""
import torch
from torch import nn, concatenate
from torch import nn, cat
from pina.utils import check_consistency
@@ -145,4 +145,4 @@ class LowRankNeuralOperator(KernelNeuralOperator):
         for module in self._integral_kernels:
             x = module(x, coords)
         # projecting
-        return self._projection_operator(concatenate((x, coords), dim=-1))
+        return self._projection_operator(cat((x, coords), dim=-1))
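Note: `torch.cat` and `torch.concatenate` are interchangeable here (`concatenate` is the newer NumPy-style alias of `cat`), so the change above only standardises on the long-standing name; a quick check:

```python
import torch

a = torch.rand(2, 3)
b = torch.rand(2, 5)

# both calls join the tensors along the last dimension
assert torch.equal(torch.cat((a, b), dim=-1),
                   torch.concatenate((a, b), dim=-1))
```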

View File

@@ -1,12 +1,11 @@
""" Module for AbstractProblem class """
from abc import ABCMeta, abstractmethod
from ..utils import merge_tensors, check_consistency
from ..utils import check_consistency
from ..domain import DomainInterface
from ..condition.domain_equation_condition import DomainEquationCondition
from ..collector import Collector
from copy import deepcopy
import torch
from .. import LabelTensor
class AbstractProblem(metaclass=ABCMeta):
"""
@@ -20,27 +19,25 @@ class AbstractProblem(metaclass=ABCMeta):
     def __init__(self):
-        self._discretized_domains = {}
-        for name, domain in self.domains.items():
-            if isinstance(domain, (torch.Tensor, LabelTensor)):
-                self._discretized_domains[name] = domain
+        # create collector to manage problem data
+        self.collector = Collector(self)
         # create hook conditions <-> problems
         for condition_name in self.conditions:
-            self.conditions[condition_name].set_problem(self)
+            self.conditions[condition_name].problem = self
-        # # variable storing all points
-        self.input_pts = {}
+        # store in collector all the available fixed points
+        # note that some points could not be stored at this stage (e.g. when
+        # sampling locations). To check that all data points are ready for
+        # training, call self.collector.full, which returns True if all
+        # points are ready.
+        self.collector.store_fixed_data()
-        # # varible to check if sampling is done. If no location
-        # # element is presented in Condition this variable is set to true
-        # self._have_sampled_points = {}
-        for condition_name in self.conditions:
-            self._discretized_domains[condition_name] = False
-        # # put in self.input_pts all the points that we don't need to sample
-        self._span_condition_points()
+
+    @property
+    def input_pts(self):
+        return self.collector.data_collections
def __deepcopy__(self, memo):
"""
Implements deepcopy for the
@@ -85,19 +82,6 @@ class AbstractProblem(metaclass=ABCMeta):
     def input_variables(self, variables):
         raise RuntimeError

-    @property
-    @abstractmethod
-    def domains(self):
-        """
-        The domain(s) where the conditions of the AbstractProblem are valid.
-        If more than one domain type is passed, a list of Location is
-        retured.
-        :return: the domain(s) of ``self``
-        :rtype: list[Location]
-        """
-        pass

     @property
     @abstractmethod
     def output_variables(self):
@@ -114,34 +98,8 @@ class AbstractProblem(metaclass=ABCMeta):
"""
return self._conditions
def _span_condition_points(self):
"""
Simple function to get the condition points
"""
for condition_name in self.conditions:
condition = self.conditions[condition_name]
if hasattr(condition, "input_points"):
samples = condition.input_points
self.input_pts[condition_name] = samples
self._discretized_domains[condition_name] = True
if hasattr(self, "unknown_parameter_domain"):
# initialize the unknown parameters of the inverse problem given
# the domain the user gives
self.unknown_parameters = {}
for i, var in enumerate(self.unknown_variables):
range_var = self.unknown_parameter_domain.range_[var]
tensor_var = (
torch.rand(1, requires_grad=True) * range_var[1]
+ range_var[0]
)
self.unknown_parameters[var] = torch.nn.Parameter(
tensor_var
)
     def discretise_domain(
-        self, n, mode="random", variables="all", domains="all"
+        self, n, mode="random", variables="all", locations="all"
     ):
"""
Generate a set of points to span the `Location` of all the conditions of
@@ -172,119 +130,38 @@ class AbstractProblem(metaclass=ABCMeta):
         ``CartesianDomain``.
         """
-        # check consistecy n
+        # check consistency n, mode, variables, locations
         check_consistency(n, int)
-        # check consistency mode
         check_consistency(mode, str)
+        check_consistency(variables, str)
+        check_consistency(locations, str)

-        if mode not in ["random", "grid", "lh", "chebyshev", "latin"]:
+        # check correct sampling mode
+        if mode not in DomainInterface.available_sampling_modes:
             raise TypeError(f"mode {mode} not valid.")

-        # check consistency variables
+        # check correct variables
         if variables == "all":
             variables = self.input_variables
         else:
-            check_consistency(variables, str)
-            if sorted(variables) != sorted(self.input_variables):
-                TypeError(
-                    f"Wrong variables for sampling. Variables ",
-                    f"should be in {self.input_variables}.",
-                )
+            for variable in variables:
+                if variable not in self.input_variables:
+                    TypeError(
+                        f"Wrong variables for sampling. Variables ",
+                        f"should be in {self.input_variables}.",
+                    )
-        # # check consistency location # TODO: check if this is needed (from 0.1)
-        # locations_to_sample = [
-        #     condition
-        #     for condition in self.conditions
-        #     if hasattr(self.conditions[condition], "location")
-        # ]
-        # if locations == "all":
-        #     # only locations that can be sampled
-        #     locations = locations_to_sample
-        # else:
-        #     check_consistency(locations, str)
-        #     if sorted(locations) != sorted(locations_to_sample):
-        if domains == "all":
-            domains = [condition for condition in self.conditions]
-        else:
-            check_consistency(domains, str)
-        print(domains)
-        if sorted(domains) != sorted(self.conditions):
-            TypeError(
-                f"Wrong locations for sampling. Location ",
-                f"should be in {locations_to_sample}.",
-            )
-        # sampling
-        for d in domains:
-            condition = self.conditions[d]
-            # we try to check if we have already sampled
-            try:
-                already_sampled = [self.input_pts[d]]
-            # if we have not sampled, a key error is thrown
-            except KeyError:
-                already_sampled = []
-            # if we have already sampled fully the condition
-            # but we want to sample again we set already_sampled
-            # to an empty list since we need to sample again, and
-            # self._have_sampled_points to False.
-            if self._discretized_domains[d]:
-                already_sampled = []
-                self._discretized_domains[d] = False
-            print(condition.domain)
-            print(d)
-            # build samples
-            samples = [
-                self.domains[d].sample(n=n, mode=mode, variables=variables)
-            ] + already_sampled
-            pts = merge_tensors(samples)
-            self.input_pts[d] = pts
-            # the condition is sampled if input_pts contains all labels
-            if sorted(self.input_pts[d].labels) == sorted(
-                self.input_variables
-            ):
-                # self._have_sampled_points[location] = True
-                # self.input_pts[location] = self.input_pts[location].extract(
-                #     sorted(self.input_variables)
-                # )
-                self._have_sampled_points[d] = True
-    def add_points(self, new_points):
-        """
-        Adding points to the already sampled points.
-
-        :param dict new_points: a dictionary with key the location to add the points
-            and values the torch.Tensor points.
-        """
-        if sorted(new_points.keys()) != sorted(self.conditions):
-            TypeError(
-                f"Wrong locations for new points. Location ",
-                f"should be in {self.conditions}.",
-            )
-        for location in new_points.keys():
-            # extract old and new points
-            old_pts = self.input_pts[location]
-            new_pts = new_points[location]
-            # if they don't have the same variables error
-            if sorted(old_pts.labels) != sorted(new_pts.labels):
-                TypeError(
-                    f"Not matching variables for old and new points "
-                    f"in condition {location}."
-                )
-            if old_pts.labels != new_pts.labels:
-                new_pts = torch.hstack(
-                    [new_pts.extract([i]) for i in old_pts.labels]
-                )
-                new_pts.labels = old_pts.labels
-            # merging
-            merged_pts = torch.vstack([old_pts, new_pts])
-            merged_pts.labels = old_pts.labels
-            self.input_pts[location] = merged_pts
+        # check correct location
+        if locations == "all":
+            locations = [name for name in self.conditions.keys()]
+        else:
+            if not isinstance(locations, (list)):
+                locations = [locations]
+            for loc in locations:
+                if not isinstance(self.conditions[loc], DomainEquationCondition):
+                    raise TypeError(
+                        f"Wrong locations passed, locations for sampling "
+                        f"should be in {[loc for loc in locations if not isinstance(self.conditions[loc], DomainEquationCondition)]}.",
+                    )

+        # store data
+        self.collector.store_sample_domains(n, mode, variables, locations)
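Note: after this refactor `discretise_domain` is a thin wrapper that validates its arguments and delegates the sampling to the collector; a minimal usage sketch (the problem class and condition name are illustrative, not part of this commit):

```python
# `Poisson2D` is a hypothetical AbstractProblem subclass with a
# domain-based condition named "interior".
problem = Poisson2D()

# sample 100 points per domain condition with latin hypercube sampling
problem.discretise_domain(n=100, mode="lh", locations="all")

# points are stored in the collector and exposed through `input_pts`
print(problem.collector.full)           # True once every condition is ready
print(problem.input_pts["interior"])    # {"input_points": ..., "equation": ...}
```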

View File

@@ -45,6 +45,20 @@ class InverseProblem(AbstractProblem):
     >>> 'data': Condition(CartesianDomain({'x': [0, 1]}), Equation(solution_data))
     """

+    def __init__(self):
+        super().__init__()
+        # storing unknown_parameters for optimization
+        self.unknown_parameters = {}
+        for i, var in enumerate(self.unknown_variables):
+            range_var = self.unknown_parameter_domain.range_[var]
+            tensor_var = (
+                torch.rand(1, requires_grad=True) * range_var[1]
+                + range_var[0]
+            )
+            self.unknown_parameters[var] = torch.nn.Parameter(
+                tensor_var
+            )
+
     @abstractmethod
     def unknown_parameter_domain(self):
         """