* Adding Collector for handling data sampling/collection before dataset/dataloader
* Modify domain by adding sample_mode, variables as property * Small change concatenate -> cat in lno/avno * Create different factory classes for conditions
This commit is contained in:
committed by
Nicola Demo
parent
aef5a5d590
commit
1bd3f40f54
@@ -1,12 +1,11 @@
|
||||
""" Module for AbstractProblem class """
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from ..utils import merge_tensors, check_consistency
|
||||
from ..utils import check_consistency
|
||||
from ..domain import DomainInterface
|
||||
from ..condition.domain_equation_condition import DomainEquationCondition
|
||||
from ..collector import Collector
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
|
||||
from .. import LabelTensor
|
||||
|
||||
|
||||
class AbstractProblem(metaclass=ABCMeta):
|
||||
"""
|
||||
@@ -20,27 +19,25 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
|
||||
def __init__(self):
|
||||
|
||||
self._discretized_domains = {}
|
||||
|
||||
for name, domain in self.domains.items():
|
||||
if isinstance(domain, (torch.Tensor, LabelTensor)):
|
||||
self._discretized_domains[name] = domain
|
||||
# create collector to manage problem data
|
||||
self.collector = Collector(self)
|
||||
|
||||
# create hook conditions <-> problems
|
||||
for condition_name in self.conditions:
|
||||
self.conditions[condition_name].set_problem(self)
|
||||
self.conditions[condition_name].problem = self
|
||||
|
||||
# # variable storing all points
|
||||
self.input_pts = {}
|
||||
# store in collector all the available fixed points
|
||||
# note that some points could not be stored at this stage (e.g. when
|
||||
# sampling locations). To check that all data points are ready for
|
||||
# training all type self.collector.full, which returns true if all
|
||||
# points are ready.
|
||||
self.collector.store_fixed_data()
|
||||
|
||||
# # varible to check if sampling is done. If no location
|
||||
# # element is presented in Condition this variable is set to true
|
||||
# self._have_sampled_points = {}
|
||||
for condition_name in self.conditions:
|
||||
self._discretized_domains[condition_name] = False
|
||||
|
||||
# # put in self.input_pts all the points that we don't need to sample
|
||||
self._span_condition_points()
|
||||
|
||||
@property
|
||||
def input_pts(self):
|
||||
return self.collector.data_collections
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
"""
|
||||
Implements deepcopy for the
|
||||
@@ -85,19 +82,6 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
def input_variables(self, variables):
|
||||
raise RuntimeError
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def domains(self):
|
||||
"""
|
||||
The domain(s) where the conditions of the AbstractProblem are valid.
|
||||
If more than one domain type is passed, a list of Location is
|
||||
retured.
|
||||
|
||||
:return: the domain(s) of ``self``
|
||||
:rtype: list[Location]
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def output_variables(self):
|
||||
@@ -114,34 +98,8 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
"""
|
||||
return self._conditions
|
||||
|
||||
|
||||
|
||||
def _span_condition_points(self):
|
||||
"""
|
||||
Simple function to get the condition points
|
||||
"""
|
||||
for condition_name in self.conditions:
|
||||
condition = self.conditions[condition_name]
|
||||
if hasattr(condition, "input_points"):
|
||||
samples = condition.input_points
|
||||
self.input_pts[condition_name] = samples
|
||||
self._discretized_domains[condition_name] = True
|
||||
if hasattr(self, "unknown_parameter_domain"):
|
||||
# initialize the unknown parameters of the inverse problem given
|
||||
# the domain the user gives
|
||||
self.unknown_parameters = {}
|
||||
for i, var in enumerate(self.unknown_variables):
|
||||
range_var = self.unknown_parameter_domain.range_[var]
|
||||
tensor_var = (
|
||||
torch.rand(1, requires_grad=True) * range_var[1]
|
||||
+ range_var[0]
|
||||
)
|
||||
self.unknown_parameters[var] = torch.nn.Parameter(
|
||||
tensor_var
|
||||
)
|
||||
|
||||
def discretise_domain(
|
||||
self, n, mode="random", variables="all", domains="all"
|
||||
self, n, mode="random", variables="all", locations="all"
|
||||
):
|
||||
"""
|
||||
Generate a set of points to span the `Location` of all the conditions of
|
||||
@@ -172,119 +130,38 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
``CartesianDomain``.
|
||||
"""
|
||||
|
||||
# check consistecy n
|
||||
# check consistecy n, mode, variables, locations
|
||||
check_consistency(n, int)
|
||||
|
||||
# check consistency mode
|
||||
check_consistency(mode, str)
|
||||
if mode not in ["random", "grid", "lh", "chebyshev", "latin"]:
|
||||
check_consistency(variables, str)
|
||||
check_consistency(locations, str)
|
||||
|
||||
# check correct sampling mode
|
||||
if mode not in DomainInterface.available_sampling_modes:
|
||||
raise TypeError(f"mode {mode} not valid.")
|
||||
|
||||
# check consistency variables
|
||||
# check correct variables
|
||||
if variables == "all":
|
||||
variables = self.input_variables
|
||||
else:
|
||||
check_consistency(variables, str)
|
||||
|
||||
if sorted(variables) != sorted(self.input_variables):
|
||||
TypeError(
|
||||
f"Wrong variables for sampling. Variables ",
|
||||
f"should be in {self.input_variables}.",
|
||||
)
|
||||
|
||||
# # check consistency location # TODO: check if this is needed (from 0.1)
|
||||
# locations_to_sample = [
|
||||
# condition
|
||||
# for condition in self.conditions
|
||||
# if hasattr(self.conditions[condition], "location")
|
||||
# ]
|
||||
# if locations == "all":
|
||||
# # only locations that can be sampled
|
||||
# locations = locations_to_sample
|
||||
# else:
|
||||
# check_consistency(locations, str)
|
||||
|
||||
# if sorted(locations) != sorted(locations_to_sample):
|
||||
if domains == "all":
|
||||
domains = [condition for condition in self.conditions]
|
||||
else:
|
||||
check_consistency(domains, str)
|
||||
print(domains)
|
||||
if sorted(domains) != sorted(self.conditions):
|
||||
TypeError(
|
||||
f"Wrong locations for sampling. Location ",
|
||||
f"should be in {locations_to_sample}.",
|
||||
)
|
||||
|
||||
# sampling
|
||||
for d in domains:
|
||||
condition = self.conditions[d]
|
||||
|
||||
# we try to check if we have already sampled
|
||||
try:
|
||||
already_sampled = [self.input_pts[d]]
|
||||
# if we have not sampled, a key error is thrown
|
||||
except KeyError:
|
||||
already_sampled = []
|
||||
|
||||
# if we have already sampled fully the condition
|
||||
# but we want to sample again we set already_sampled
|
||||
# to an empty list since we need to sample again, and
|
||||
# self._have_sampled_points to False.
|
||||
if self._discretized_domains[d]:
|
||||
already_sampled = []
|
||||
self._discretized_domains[d] = False
|
||||
print(condition.domain)
|
||||
print(d)
|
||||
# build samples
|
||||
samples = [
|
||||
self.domains[d].sample(n=n, mode=mode, variables=variables)
|
||||
] + already_sampled
|
||||
pts = merge_tensors(samples)
|
||||
self.input_pts[d] = pts
|
||||
|
||||
# the condition is sampled if input_pts contains all labels
|
||||
if sorted(self.input_pts[d].labels) == sorted(
|
||||
self.input_variables
|
||||
):
|
||||
# self._have_sampled_points[location] = True
|
||||
# self.input_pts[location] = self.input_pts[location].extract(
|
||||
# sorted(self.input_variables)
|
||||
# )
|
||||
self._have_sampled_points[d] = True
|
||||
|
||||
def add_points(self, new_points):
|
||||
"""
|
||||
Adding points to the already sampled points.
|
||||
|
||||
:param dict new_points: a dictionary with key the location to add the points
|
||||
and values the torch.Tensor points.
|
||||
"""
|
||||
|
||||
if sorted(new_points.keys()) != sorted(self.conditions):
|
||||
TypeError(
|
||||
f"Wrong locations for new points. Location ",
|
||||
f"should be in {self.conditions}.",
|
||||
)
|
||||
|
||||
for location in new_points.keys():
|
||||
# extract old and new points
|
||||
old_pts = self.input_pts[location]
|
||||
new_pts = new_points[location]
|
||||
|
||||
# if they don't have the same variables error
|
||||
if sorted(old_pts.labels) != sorted(new_pts.labels):
|
||||
for variable in variables:
|
||||
if variable not in self.input_variables:
|
||||
TypeError(
|
||||
f"Not matching variables for old and new points "
|
||||
f"in condition {location}."
|
||||
f"Wrong variables for sampling. Variables ",
|
||||
f"should be in {self.input_variables}.",
|
||||
)
|
||||
if old_pts.labels != new_pts.labels:
|
||||
new_pts = torch.hstack(
|
||||
[new_pts.extract([i]) for i in old_pts.labels]
|
||||
)
|
||||
new_pts.labels = old_pts.labels
|
||||
|
||||
# merging
|
||||
merged_pts = torch.vstack([old_pts, new_pts])
|
||||
merged_pts.labels = old_pts.labels
|
||||
self.input_pts[location] = merged_pts
|
||||
# check correct location
|
||||
if locations == "all":
|
||||
locations = [name for name in self.conditions.keys()]
|
||||
else:
|
||||
if not isinstance(locations, (list)):
|
||||
locations = [locations]
|
||||
for loc in locations:
|
||||
if not isinstance(self.conditions[loc], DomainEquationCondition):
|
||||
raise TypeError(
|
||||
f"Wrong locations passed, locations for sampling "
|
||||
f"should be in {[loc for loc in locations if not isinstance(self.conditions[loc], DomainEquationCondition)]}.",
|
||||
)
|
||||
|
||||
# store data
|
||||
self.collector.store_sample_domains(n, mode, variables, locations)
|
||||
|
||||
@@ -45,6 +45,20 @@ class InverseProblem(AbstractProblem):
|
||||
>>> 'data': Condition(CartesianDomain({'x': [0, 1]}), Equation(solution_data))
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# storing unknown_parameters for optimization
|
||||
self.unknown_parameters = {}
|
||||
for i, var in enumerate(self.unknown_variables):
|
||||
range_var = self.unknown_parameter_domain.range_[var]
|
||||
tensor_var = (
|
||||
torch.rand(1, requires_grad=True) * range_var[1]
|
||||
+ range_var[0]
|
||||
)
|
||||
self.unknown_parameters[var] = torch.nn.Parameter(
|
||||
tensor_var
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def unknown_parameter_domain(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user