* Add Collector for handling data sampling/collection before dataset/dataloader

* Modify domain by adding sample_mode and variables as properties
* Small change: concatenate -> cat in lno/avno
* Create different factory classes for conditions
This commit is contained in:
Dario Coscia
2024-10-04 13:57:18 +02:00
committed by Nicola Demo
parent aef5a5d590
commit 1bd3f40f54
18 changed files with 225 additions and 277 deletions
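
A rough usage sketch of the reworked sampling flow (Poisson is an assumed concrete AbstractProblem subclass; discretise_domain, collector, collector.full and input_pts appear in the diff below):

problem = Poisson()  # assumed concrete problem definition
# fixed points are stored in the collector at construction time
problem.collector.full  # False while some domains still need to be sampled
# sample the remaining domains; the resulting points are stored in the collector
problem.discretise_domain(n=100, mode="random", variables="all", locations="all")
problem.collector.full  # True once every condition has its points
problem.input_pts  # now simply returns collector.data_collections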


@@ -1,12 +1,11 @@
""" Module for AbstractProblem class """
from abc import ABCMeta, abstractmethod
from ..utils import merge_tensors, check_consistency
from ..utils import check_consistency
from ..domain import DomainInterface
from ..condition.domain_equation_condition import DomainEquationCondition
from ..collector import Collector
from copy import deepcopy
import torch
from .. import LabelTensor
class AbstractProblem(metaclass=ABCMeta):
"""
@@ -20,27 +19,25 @@ class AbstractProblem(metaclass=ABCMeta):
def __init__(self):
self._discretized_domains = {}
for name, domain in self.domains.items():
if isinstance(domain, (torch.Tensor, LabelTensor)):
self._discretized_domains[name] = domain
# create collector to manage problem data
self.collector = Collector(self)
# create hook conditions <-> problems
for condition_name in self.conditions:
self.conditions[condition_name].set_problem(self)
self.conditions[condition_name].problem = self
# # variable storing all points
self.input_pts = {}
# store in collector all the available fixed points
# note that some points might not be stored at this stage (e.g. when
# locations still need to be sampled). To check that all data points are
# ready for training, query self.collector.full, which returns True only
# if all points are ready.
self.collector.store_fixed_data()
# # variable to check if sampling is done. If no location
# # element is present in Condition this variable is set to true
# self._have_sampled_points = {}
for condition_name in self.conditions:
self._discretized_domains[condition_name] = False
# # put in self.input_pts all the points that we don't need to sample
self._span_condition_points()
@property
def input_pts(self):
return self.collector.data_collections
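# usage sketch (the condition name "gamma1" is purely illustrative):
#   problem.input_pts["gamma1"]  # the points stored in collector.data_collections["gamma1"]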
def __deepcopy__(self, memo):
"""
Implements deepcopy for the
@@ -85,19 +82,6 @@ class AbstractProblem(metaclass=ABCMeta):
def input_variables(self, variables):
raise RuntimeError
@property
@abstractmethod
def domains(self):
"""
The domain(s) where the conditions of the AbstractProblem are valid.
If more than one domain type is passed, a list of Location is
returned.
:return: the domain(s) of ``self``
:rtype: list[Location]
"""
pass
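# sketch of what a concrete problem is expected to declare here (names and
# bounds are illustrative; CartesianDomain is the class mentioned in the
# discretise_domain docstring further down):
#   domains = {"D": CartesianDomain({"x": [0, 1], "y": [0, 1]})}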
@property
@abstractmethod
def output_variables(self):
@@ -114,34 +98,8 @@ class AbstractProblem(metaclass=ABCMeta):
"""
return self._conditions
def _span_condition_points(self):
"""
Simple function to get the condition points
"""
for condition_name in self.conditions:
condition = self.conditions[condition_name]
if hasattr(condition, "input_points"):
samples = condition.input_points
self.input_pts[condition_name] = samples
self._discretized_domains[condition_name] = True
if hasattr(self, "unknown_parameter_domain"):
# initialize the unknown parameters of the inverse problem using
# the domain provided by the user
self.unknown_parameters = {}
for i, var in enumerate(self.unknown_variables):
range_var = self.unknown_parameter_domain.range_[var]
tensor_var = (
torch.rand(1, requires_grad=True) * range_var[1]
+ range_var[0]
)
self.unknown_parameters[var] = torch.nn.Parameter(
tensor_var
)
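# sketch of the inverse-problem branch above (the variable name "mu" and its
# range are illustrative): with unknown_parameter_domain exposing
# range_["mu"] = [0, 1] and unknown_variables = ["mu"], the loop produces
#   self.unknown_parameters["mu"]  # a trainable torch.nn.Parameter of shape (1,)
# initialised randomly from the bounds of the given range.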
def discretise_domain(
self, n, mode="random", variables="all", domains="all"
self, n, mode="random", variables="all", locations="all"
):
"""
Generate a set of points to span the `Location` of all the conditions of
@@ -172,119 +130,38 @@ class AbstractProblem(metaclass=ABCMeta):
``CartesianDomain``.
"""
# check consistency n
# check consistency n, mode, variables, locations
check_consistency(n, int)
# check consistency mode
check_consistency(mode, str)
if mode not in ["random", "grid", "lh", "chebyshev", "latin"]:
check_consistency(variables, str)
check_consistency(locations, str)
# check correct sampling mode
if mode not in DomainInterface.available_sampling_modes:
raise TypeError(f"mode {mode} not valid.")
# check consistency variables
# check correct variables
if variables == "all":
variables = self.input_variables
else:
check_consistency(variables, str)
if sorted(variables) != sorted(self.input_variables):
raise TypeError(
f"Wrong variables for sampling. Variables "
f"should be in {self.input_variables}."
)
# # check consistency location # TODO: check if this is needed (from 0.1)
# locations_to_sample = [
# condition
# for condition in self.conditions
# if hasattr(self.conditions[condition], "location")
# ]
# if locations == "all":
# # only locations that can be sampled
# locations = locations_to_sample
# else:
# check_consistency(locations, str)
# if sorted(locations) != sorted(locations_to_sample):
if domains == "all":
domains = [condition for condition in self.conditions]
else:
check_consistency(domains, str)
if sorted(domains) != sorted(self.conditions):
raise TypeError(
f"Wrong domains for sampling. Domains "
f"should be in {sorted(self.conditions)}."
)
# sampling
for d in domains:
condition = self.conditions[d]
# we try to check if we have already sampled
try:
already_sampled = [self.input_pts[d]]
# if we have not sampled, a key error is thrown
except KeyError:
already_sampled = []
# if the condition has already been fully sampled but we want to
# sample again, we reset already_sampled to an empty list and set
# self._discretized_domains[d] to False, since we need to sample again.
if self._discretized_domains[d]:
already_sampled = []
self._discretized_domains[d] = False
# build samples
samples = [
self.domains[d].sample(n=n, mode=mode, variables=variables)
] + already_sampled
pts = merge_tensors(samples)
self.input_pts[d] = pts
# the condition is sampled if input_pts contains all labels
if sorted(self.input_pts[d].labels) == sorted(
self.input_variables
):
# self._have_sampled_points[location] = True
# self.input_pts[location] = self.input_pts[location].extract(
# sorted(self.input_variables)
# )
self._discretized_domains[d] = True
def add_points(self, new_points):
"""
Adding points to the already sampled points.
:param dict new_points: a dictionary with key the location to add the points
and values the torch.Tensor points.
"""
if sorted(new_points.keys()) != sorted(self.conditions):
raise TypeError(
f"Wrong locations for new points. Locations "
f"should be in {self.conditions}."
)
for location in new_points.keys():
# extract old and new points
old_pts = self.input_pts[location]
new_pts = new_points[location]
# if they don't have the same variables error
if sorted(old_pts.labels) != sorted(new_pts.labels):
raise TypeError(
f"Not matching variables for old and new points "
f"in condition {location}."
)
for variable in variables:
if variable not in self.input_variables:
raise TypeError(
f"Wrong variables for sampling. Variables "
f"should be in {self.input_variables}."
)
if old_pts.labels != new_pts.labels:
new_pts = torch.hstack(
[new_pts.extract([i]) for i in old_pts.labels]
)
new_pts.labels = old_pts.labels
# merging
merged_pts = torch.vstack([old_pts, new_pts])
merged_pts.labels = old_pts.labels
self.input_pts[location] = merged_pts
# check correct location
if locations == "all":
locations = [name for name in self.conditions.keys()]
else:
if not isinstance(locations, list):
locations = [locations]
for loc in locations:
if not isinstance(self.conditions[loc], DomainEquationCondition):
raise TypeError(
f"Wrong locations passed, locations for sampling "
f"should be in {[loc for loc in locations if not isinstance(self.conditions[loc], DomainEquationCondition)]}.",
)
# store data
self.collector.store_sample_domains(n, mode, variables, locations)
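# end-to-end sketch of the new sampling path (argument values are illustrative;
# only DomainEquationCondition conditions can be sampled, and valid modes are
# listed in DomainInterface.available_sampling_modes):
#   problem.discretise_domain(n=20, mode="random", variables="all", locations="all")
#   # after the checks above, the actual sampling and storage are delegated to
#   #   self.collector.store_sample_domains(20, "random", "all", locations)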