Clean up logic, fix problems for tutorial1
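For context, a minimal sketch of the workflow this diff moves toward: domains are discretised on the problem itself and stored in problem.discretised_domains, and the Trainer checks are_all_domains_discretised instead of collector.full. The Poisson problem name below is an illustrative assumption from the PINA tutorials, not part of this commit.

# Illustrative only: assumes a PINA-style problem class named Poisson.
problem = Poisson()

# New signature: sample every domain (or a subset, e.g. domains=['bound1'])
# and store the resulting points in problem.discretised_domains.
problem.discretise_domain(n=10, mode='grid', domains='all')

# The Trainer now requires every domain to be discretised before it is built,
# checking problem.are_all_domains_discretised rather than collector.full.
assert problem.are_all_domains_discretised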
@@ -62,46 +62,58 @@ class Collector:
            # condition now is ready
            self._is_conditions_ready[condition_name] = True

    def store_sample_domains(self, n, mode, variables, sample_locations):
        # loop over all locations
        for loc in sample_locations:
            # get condition
            condition = self.problem.conditions[loc]
            condition_domain = condition.domain
            if isinstance(condition_domain, str):
                condition_domain = self.problem.domains[condition_domain]
            keys = ["input_points", "equation"]
            # if the condition is not ready, we get and store the data
            if not self._is_conditions_ready[loc]:
                # if it is the first time we sample
                if not self.data_collections[loc]:
                    already_sampled = []
                # if we have sampled the condition but not all variables
                else:
                    already_sampled = [
                        self.data_collections[loc]['input_points']
                    ]
            # if the condition is ready but we want to sample again
            else:
                self._is_conditions_ready[loc] = False
                already_sampled = []
    def store_sample_domains(self):
        """
        Add
        """
        for condition_name in self.problem.conditions:
            condition = self.problem.conditions[condition_name]
            if not hasattr(condition, "domain"):
                continue

            # get the samples
            samples = [
                condition_domain.sample(n=n, mode=mode,
                                        variables=variables)
            ] + already_sampled
            pts = merge_tensors(samples)
            if set(pts.labels).issubset(sorted(self.problem.input_variables)):
                pts = pts.sort_labels()
                if sorted(pts.labels) == sorted(self.problem.input_variables):
                    self._is_conditions_ready[loc] = True
                values = [pts, condition.equation]
                self.data_collections[loc] = dict(zip(keys, values))
            else:
                raise RuntimeError(
                    'Try to sample variables which are not in problem defined '
                    'in the problem')
            samples = self.problem.discretised_domains[condition.domain]

            self.data_collections[condition_name] = {
                'input_points': samples,
                'equation': condition.equation
            }

            # # get condition
            # condition = self.problem.conditions[loc]
            # condition_domain = condition.domain
            # if isinstance(condition_domain, str):
            # condition_domain = self.problem.domains[condition_domain]
            # keys = ["input_points", "equation"]
            # # if the condition is not ready, we get and store the data
            # if not self._is_conditions_ready[loc]:
            # # if it is the first time we sample
            # if not self.data_collections[loc]:
            # already_sampled = []
            # # if we have sampled the condition but not all variables
            # else:
            # already_sampled = [
            # self.data_collections[loc]['input_points']
            # ]
            # # if the condition is ready but we want to sample again
            # else:
            # self._is_conditions_ready[loc] = False
            # already_sampled = []
            # # get the samples
            # samples = [
            # condition_domain.sample(n=n, mode=mode,
            # variables=variables)
            # ] + already_sampled
            # pts = merge_tensors(samples)
            # if set(pts.labels).issubset(sorted(self.problem.input_variables)):
            # pts = pts.sort_labels()
            # if sorted(pts.labels) == sorted(self.problem.input_variables):
            # self._is_conditions_ready[loc] = True
            # values = [pts, condition.equation]
            # self.data_collections[loc] = dict(zip(keys, values))
            # else:
            # raise RuntimeError(
            # 'Try to sample variables which are not in problem defined '
            # 'in the problem')

    def add_points(self, new_points_dict):
        """
@@ -7,7 +7,7 @@ from torch.utils.data import DataLoader, BatchSampler, SequentialSampler, \
    RandomSampler
from torch.utils.data.distributed import DistributedSampler
from .dataset import PinaDatasetFactory

from ..collector import Collector

class DummyDataloader:
    def __init__(self, dataset, device):
@@ -87,7 +87,7 @@ class PinaDataModule(LightningDataModule):
    """

    def __init__(self,
                 collector,
                 problem,
                 train_size=.7,
                 test_size=.2,
                 val_size=.1,
@@ -99,7 +99,6 @@ class PinaDataModule(LightningDataModule):
                 ):
        """
        Initialize the object, creating dataset based on input problem
        :param Collector collector: PINA problem
        :param train_size: number/percentage of elements in train split
        :param test_size: number/percentage of elements in test split
        :param val_size: number/percentage of elements in evaluation split
@@ -135,6 +134,10 @@ class PinaDataModule(LightningDataModule):
            self.predict_dataset = None
        else:
            self.predict_dataloader = super().predict_dataloader

        collector = Collector(problem)
        collector.store_fixed_data()
        collector.store_sample_domains()
        self.collector_splits = self._create_splits(collector, splits_dict)
        self.transfer_batch_to_device = self._transfer_batch_to_device

@@ -58,6 +58,7 @@ class PinaTensorDataset(PinaDataset):
    def __init__(self, conditions_dict, max_conditions_lengths,
                 automatic_batching):
        super().__init__(conditions_dict, max_conditions_lengths)

        if automatic_batching:
            self._getitem_func = self._getitem_int
        else:
@@ -20,8 +20,8 @@ class AbstractProblem(metaclass=ABCMeta):

    def __init__(self):

        self.discretised_domains = {}
        # create collector to manage problem data
        self._collector = Collector(self)

        # create hook conditions <-> problems
        for condition_name in self.conditions:
@@ -32,12 +32,12 @@ class AbstractProblem(metaclass=ABCMeta):
        # sampling locations). To check that all data points are ready for
        # training all type self.collector.full, which returns true if all
        # points are ready.
        self.collector.store_fixed_data()
        # self.collector.store_fixed_data()
        self._batching_dimension = 0

    @property
    def collector(self):
        return self._collector
    # @property
    # def collector(self):
    # return self._collector

    @property
    def batching_dimension(self):
@@ -74,6 +74,21 @@ class AbstractProblem(metaclass=ABCMeta):
            setattr(result, k, deepcopy(v, memo))
        return result

    @property
    def are_all_domains_discretised(self):
        """
        Check if all the domains are discretised.

        :return: True if all the domains are discretised, False otherwise
        :rtype: bool
        """
        return all(
            [
                domain in self.discretised_domains
                for domain in self.domains.keys()
            ]
        )

    @property
    def input_variables(self):
        """
@@ -120,7 +135,7 @@ class AbstractProblem(metaclass=ABCMeta):
                          n,
                          mode="random",
                          variables="all",
                          locations="all"):
                          domains="all"):
        """
        Generate a set of points to span the `Location` of all the conditions of
        the problem.
@@ -134,12 +149,12 @@ class AbstractProblem(metaclass=ABCMeta):
            chebyshev sampling, ``chebyshev``; grid sampling ``grid``.
        :param variables: problem's variables to be sampled, defaults to 'all'.
        :type variables: str | list[str]
        :param locations: problem's locations from where to sample, defaults to 'all'.
        :param domain: problem's domain from where to sample, defaults to 'all'.
        :type locations: str

        :Example:
            >>> pinn.discretise_domain(n=10, mode='grid')
            >>> pinn.discretise_domain(n=10, mode='grid', location=['bound1'])
            >>> pinn.discretise_domain(n=10, mode='grid', domain=['bound1'])
            >>> pinn.discretise_domain(n=10, mode='grid', variables=['x'])

        .. warning::
@@ -154,11 +169,11 @@ class AbstractProblem(metaclass=ABCMeta):
        check_consistency(n, int)
        check_consistency(mode, str)
        check_consistency(variables, str)
        check_consistency(locations, str)
        check_consistency(domains, (list, str))

        # check correct sampling mode
        if mode not in DomainInterface.available_sampling_modes:
            raise TypeError(f"mode {mode} not valid.")
        # if mode not in DomainInterface.available_sampling_modes:
        # raise TypeError(f"mode {mode} not valid.")

        # check correct variables
        if variables == "all":
@@ -170,23 +185,21 @@ class AbstractProblem(metaclass=ABCMeta):
                    f"should be in {self.input_variables}.",
                )
        # check correct location
        if locations == "all":
            locations = [
                name for name in self.conditions.keys()
                if isinstance(self.conditions[name], DomainEquationCondition)
            ]
        else:
            if not isinstance(locations, (list)):
                locations = [locations]
            for loc in locations:
                if not isinstance(self.conditions[loc], DomainEquationCondition):
                    raise TypeError(
                        f"Wrong locations passed, locations for sampling "
                        f"should be in {[loc for loc in locations if isinstance(self.conditions[loc], DomainEquationCondition)]}.",
                    )
        if domains == "all":
            domains = self.domains.keys()
        elif not isinstance(domains, (list)):
            domains = [domains]

        for domain in domains:
            self.discretised_domains[domain] = (
                self.domains[domain].sample(n, mode, variables)
            )
        # if not isinstance(self.conditions[loc], DomainEquationCondition):
        # raise TypeError(
        # f"Wrong locations passed, locations for sampling "
        # f"should be in {[loc for loc in locations if isinstance(self.conditions[loc], DomainEquationCondition)]}.",
        # )

        # store data
        self.collector.store_sample_domains(n, mode, variables, locations)

    def add_points(self, new_points_dict):
        self.collector.add_points(new_points_dict)
        # self.collector.store_sample_domains()
        # self.collector.store_sample_domains(n, mode, variables, domain)
@@ -63,17 +63,17 @@ class Trainer(lightning.pytorch.Trainer):
        during training, there is no need to define to touch the
        trainer dataloader, just call the method.
        """
        if not self.solver.problem.collector.full:
        if not self.solver.problem.are_all_domains_discretised:
            error_message = '\n'.join([
                f"""{" " * 13} ---> Condition {key} {"sampled" if value else
                "not sampled"}""" for key, value in
                self._solver.problem.collector._is_conditions_ready.items()
                f"""{" " * 13} ---> Domain {key} {"sampled" if key in self.solver.problem.discretised_domains else
                "not sampled"}""" for key in
                self.solver.problem.domains.keys()
            ])
            raise RuntimeError('Cannot create Trainer if not all conditions '
                               'are sampled. The Trainer got the following:\n'
                               f'{error_message}')
        automatic_batching = False
        self.data_module = PinaDataModule(collector=self.solver.problem.collector,
        self.data_module = PinaDataModule(self.solver.problem,
                                          train_size=self.train_size,
                                          test_size=self.test_size,
                                          val_size=self.val_size,