clean logic, fix problems for tutorial1
This commit is contained in:
@@ -62,46 +62,58 @@ class Collector:
|
|||||||
# condition now is ready
|
# condition now is ready
|
||||||
self._is_conditions_ready[condition_name] = True
|
self._is_conditions_ready[condition_name] = True
|
||||||
|
|
||||||
def store_sample_domains(self, n, mode, variables, sample_locations):
|
def store_sample_domains(self):
|
||||||
# loop over all locations
|
"""
|
||||||
for loc in sample_locations:
|
Add
|
||||||
# get condition
|
"""
|
||||||
condition = self.problem.conditions[loc]
|
for condition_name in self.problem.conditions:
|
||||||
condition_domain = condition.domain
|
condition = self.problem.conditions[condition_name]
|
||||||
if isinstance(condition_domain, str):
|
if not hasattr(condition, "domain"):
|
||||||
condition_domain = self.problem.domains[condition_domain]
|
continue
|
||||||
keys = ["input_points", "equation"]
|
|
||||||
# if the condition is not ready, we get and store the data
|
|
||||||
if not self._is_conditions_ready[loc]:
|
|
||||||
# if it is the first time we sample
|
|
||||||
if not self.data_collections[loc]:
|
|
||||||
already_sampled = []
|
|
||||||
# if we have sampled the condition but not all variables
|
|
||||||
else:
|
|
||||||
already_sampled = [
|
|
||||||
self.data_collections[loc]['input_points']
|
|
||||||
]
|
|
||||||
# if the condition is ready but we want to sample again
|
|
||||||
else:
|
|
||||||
self._is_conditions_ready[loc] = False
|
|
||||||
already_sampled = []
|
|
||||||
|
|
||||||
# get the samples
|
samples = self.problem.discretised_domains[condition.domain]
|
||||||
samples = [
|
|
||||||
condition_domain.sample(n=n, mode=mode,
|
self.data_collections[condition_name] = {
|
||||||
variables=variables)
|
'input_points': samples,
|
||||||
] + already_sampled
|
'equation': condition.equation
|
||||||
pts = merge_tensors(samples)
|
}
|
||||||
if set(pts.labels).issubset(sorted(self.problem.input_variables)):
|
|
||||||
pts = pts.sort_labels()
|
# # get condition
|
||||||
if sorted(pts.labels) == sorted(self.problem.input_variables):
|
# condition = self.problem.conditions[loc]
|
||||||
self._is_conditions_ready[loc] = True
|
# condition_domain = condition.domain
|
||||||
values = [pts, condition.equation]
|
# if isinstance(condition_domain, str):
|
||||||
self.data_collections[loc] = dict(zip(keys, values))
|
# condition_domain = self.problem.domains[condition_domain]
|
||||||
else:
|
# keys = ["input_points", "equation"]
|
||||||
raise RuntimeError(
|
# # if the condition is not ready, we get and store the data
|
||||||
'Try to sample variables which are not in problem defined '
|
# if not self._is_conditions_ready[loc]:
|
||||||
'in the problem')
|
# # if it is the first time we sample
|
||||||
|
# if not self.data_collections[loc]:
|
||||||
|
# already_sampled = []
|
||||||
|
# # if we have sampled the condition but not all variables
|
||||||
|
# else:
|
||||||
|
# already_sampled = [
|
||||||
|
# self.data_collections[loc]['input_points']
|
||||||
|
# ]
|
||||||
|
# # if the condition is ready but we want to sample again
|
||||||
|
# else:
|
||||||
|
# self._is_conditions_ready[loc] = False
|
||||||
|
# already_sampled = []
|
||||||
|
# # get the samples
|
||||||
|
# samples = [
|
||||||
|
# condition_domain.sample(n=n, mode=mode,
|
||||||
|
# variables=variables)
|
||||||
|
# ] + already_sampled
|
||||||
|
# pts = merge_tensors(samples)
|
||||||
|
# if set(pts.labels).issubset(sorted(self.problem.input_variables)):
|
||||||
|
# pts = pts.sort_labels()
|
||||||
|
# if sorted(pts.labels) == sorted(self.problem.input_variables):
|
||||||
|
# self._is_conditions_ready[loc] = True
|
||||||
|
# values = [pts, condition.equation]
|
||||||
|
# self.data_collections[loc] = dict(zip(keys, values))
|
||||||
|
# else:
|
||||||
|
# raise RuntimeError(
|
||||||
|
# 'Try to sample variables which are not in problem defined '
|
||||||
|
# 'in the problem')
|
||||||
|
|
||||||
def add_points(self, new_points_dict):
|
def add_points(self, new_points_dict):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from torch.utils.data import DataLoader, BatchSampler, SequentialSampler, \
|
|||||||
RandomSampler
|
RandomSampler
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from .dataset import PinaDatasetFactory
|
from .dataset import PinaDatasetFactory
|
||||||
|
from ..collector import Collector
|
||||||
|
|
||||||
class DummyDataloader:
|
class DummyDataloader:
|
||||||
def __init__(self, dataset, device):
|
def __init__(self, dataset, device):
|
||||||
@@ -87,7 +87,7 @@ class PinaDataModule(LightningDataModule):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
collector,
|
problem,
|
||||||
train_size=.7,
|
train_size=.7,
|
||||||
test_size=.2,
|
test_size=.2,
|
||||||
val_size=.1,
|
val_size=.1,
|
||||||
@@ -99,7 +99,6 @@ class PinaDataModule(LightningDataModule):
|
|||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the object, creating dataset based on input problem
|
Initialize the object, creating dataset based on input problem
|
||||||
:param Collector collector: PINA problem
|
|
||||||
:param train_size: number/percentage of elements in train split
|
:param train_size: number/percentage of elements in train split
|
||||||
:param test_size: number/percentage of elements in test split
|
:param test_size: number/percentage of elements in test split
|
||||||
:param val_size: number/percentage of elements in evaluation split
|
:param val_size: number/percentage of elements in evaluation split
|
||||||
@@ -135,6 +134,10 @@ class PinaDataModule(LightningDataModule):
|
|||||||
self.predict_dataset = None
|
self.predict_dataset = None
|
||||||
else:
|
else:
|
||||||
self.predict_dataloader = super().predict_dataloader
|
self.predict_dataloader = super().predict_dataloader
|
||||||
|
|
||||||
|
collector = Collector(problem)
|
||||||
|
collector.store_fixed_data()
|
||||||
|
collector.store_sample_domains()
|
||||||
self.collector_splits = self._create_splits(collector, splits_dict)
|
self.collector_splits = self._create_splits(collector, splits_dict)
|
||||||
self.transfer_batch_to_device = self._transfer_batch_to_device
|
self.transfer_batch_to_device = self._transfer_batch_to_device
|
||||||
|
|
||||||
|
|||||||
@@ -58,6 +58,7 @@ class PinaTensorDataset(PinaDataset):
|
|||||||
def __init__(self, conditions_dict, max_conditions_lengths,
|
def __init__(self, conditions_dict, max_conditions_lengths,
|
||||||
automatic_batching):
|
automatic_batching):
|
||||||
super().__init__(conditions_dict, max_conditions_lengths)
|
super().__init__(conditions_dict, max_conditions_lengths)
|
||||||
|
|
||||||
if automatic_batching:
|
if automatic_batching:
|
||||||
self._getitem_func = self._getitem_int
|
self._getitem_func = self._getitem_int
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -20,8 +20,8 @@ class AbstractProblem(metaclass=ABCMeta):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
||||||
|
self.discretised_domains = {}
|
||||||
# create collector to manage problem data
|
# create collector to manage problem data
|
||||||
self._collector = Collector(self)
|
|
||||||
|
|
||||||
# create hook conditions <-> problems
|
# create hook conditions <-> problems
|
||||||
for condition_name in self.conditions:
|
for condition_name in self.conditions:
|
||||||
@@ -32,12 +32,12 @@ class AbstractProblem(metaclass=ABCMeta):
|
|||||||
# sampling locations). To check that all data points are ready for
|
# sampling locations). To check that all data points are ready for
|
||||||
# training all type self.collector.full, which returns true if all
|
# training all type self.collector.full, which returns true if all
|
||||||
# points are ready.
|
# points are ready.
|
||||||
self.collector.store_fixed_data()
|
# self.collector.store_fixed_data()
|
||||||
self._batching_dimension = 0
|
self._batching_dimension = 0
|
||||||
|
|
||||||
@property
|
# @property
|
||||||
def collector(self):
|
# def collector(self):
|
||||||
return self._collector
|
# return self._collector
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def batching_dimension(self):
|
def batching_dimension(self):
|
||||||
@@ -74,6 +74,21 @@ class AbstractProblem(metaclass=ABCMeta):
|
|||||||
setattr(result, k, deepcopy(v, memo))
|
setattr(result, k, deepcopy(v, memo))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@property
|
||||||
|
def are_all_domains_discretised(self):
|
||||||
|
"""
|
||||||
|
Check if all the domains are discretised.
|
||||||
|
|
||||||
|
:return: True if all the domains are discretised, False otherwise
|
||||||
|
:rtype: bool
|
||||||
|
"""
|
||||||
|
return all(
|
||||||
|
[
|
||||||
|
domain in self.discretised_domains
|
||||||
|
for domain in self.domains.keys()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_variables(self):
|
def input_variables(self):
|
||||||
"""
|
"""
|
||||||
@@ -120,7 +135,7 @@ class AbstractProblem(metaclass=ABCMeta):
|
|||||||
n,
|
n,
|
||||||
mode="random",
|
mode="random",
|
||||||
variables="all",
|
variables="all",
|
||||||
locations="all"):
|
domains="all"):
|
||||||
"""
|
"""
|
||||||
Generate a set of points to span the `Location` of all the conditions of
|
Generate a set of points to span the `Location` of all the conditions of
|
||||||
the problem.
|
the problem.
|
||||||
@@ -134,12 +149,12 @@ class AbstractProblem(metaclass=ABCMeta):
|
|||||||
chebyshev sampling, ``chebyshev``; grid sampling ``grid``.
|
chebyshev sampling, ``chebyshev``; grid sampling ``grid``.
|
||||||
:param variables: problem's variables to be sampled, defaults to 'all'.
|
:param variables: problem's variables to be sampled, defaults to 'all'.
|
||||||
:type variables: str | list[str]
|
:type variables: str | list[str]
|
||||||
:param locations: problem's locations from where to sample, defaults to 'all'.
|
:param domain: problem's domain from where to sample, defaults to 'all'.
|
||||||
:type locations: str
|
:type locations: str
|
||||||
|
|
||||||
:Example:
|
:Example:
|
||||||
>>> pinn.discretise_domain(n=10, mode='grid')
|
>>> pinn.discretise_domain(n=10, mode='grid')
|
||||||
>>> pinn.discretise_domain(n=10, mode='grid', location=['bound1'])
|
>>> pinn.discretise_domain(n=10, mode='grid', domain=['bound1'])
|
||||||
>>> pinn.discretise_domain(n=10, mode='grid', variables=['x'])
|
>>> pinn.discretise_domain(n=10, mode='grid', variables=['x'])
|
||||||
|
|
||||||
.. warning::
|
.. warning::
|
||||||
@@ -154,11 +169,11 @@ class AbstractProblem(metaclass=ABCMeta):
|
|||||||
check_consistency(n, int)
|
check_consistency(n, int)
|
||||||
check_consistency(mode, str)
|
check_consistency(mode, str)
|
||||||
check_consistency(variables, str)
|
check_consistency(variables, str)
|
||||||
check_consistency(locations, str)
|
check_consistency(domains, (list, str))
|
||||||
|
|
||||||
# check correct sampling mode
|
# check correct sampling mode
|
||||||
if mode not in DomainInterface.available_sampling_modes:
|
# if mode not in DomainInterface.available_sampling_modes:
|
||||||
raise TypeError(f"mode {mode} not valid.")
|
# raise TypeError(f"mode {mode} not valid.")
|
||||||
|
|
||||||
# check correct variables
|
# check correct variables
|
||||||
if variables == "all":
|
if variables == "all":
|
||||||
@@ -170,23 +185,21 @@ class AbstractProblem(metaclass=ABCMeta):
|
|||||||
f"should be in {self.input_variables}.",
|
f"should be in {self.input_variables}.",
|
||||||
)
|
)
|
||||||
# check correct location
|
# check correct location
|
||||||
if locations == "all":
|
if domains == "all":
|
||||||
locations = [
|
domains = self.domains.keys()
|
||||||
name for name in self.conditions.keys()
|
elif not isinstance(domains, (list)):
|
||||||
if isinstance(self.conditions[name], DomainEquationCondition)
|
domains = [domains]
|
||||||
]
|
|
||||||
else:
|
for domain in domains:
|
||||||
if not isinstance(locations, (list)):
|
self.discretised_domains[domain] = (
|
||||||
locations = [locations]
|
self.domains[domain].sample(n, mode, variables)
|
||||||
for loc in locations:
|
)
|
||||||
if not isinstance(self.conditions[loc], DomainEquationCondition):
|
# if not isinstance(self.conditions[loc], DomainEquationCondition):
|
||||||
raise TypeError(
|
# raise TypeError(
|
||||||
f"Wrong locations passed, locations for sampling "
|
# f"Wrong locations passed, locations for sampling "
|
||||||
f"should be in {[loc for loc in locations if isinstance(self.conditions[loc], DomainEquationCondition)]}.",
|
# f"should be in {[loc for loc in locations if isinstance(self.conditions[loc], DomainEquationCondition)]}.",
|
||||||
)
|
# )
|
||||||
|
|
||||||
# store data
|
# store data
|
||||||
self.collector.store_sample_domains(n, mode, variables, locations)
|
# self.collector.store_sample_domains()
|
||||||
|
# self.collector.store_sample_domains(n, mode, variables, domain)
|
||||||
def add_points(self, new_points_dict):
|
|
||||||
self.collector.add_points(new_points_dict)
|
|
||||||
@@ -63,17 +63,17 @@ class Trainer(lightning.pytorch.Trainer):
|
|||||||
during training, there is no need to define to touch the
|
during training, there is no need to define to touch the
|
||||||
trainer dataloader, just call the method.
|
trainer dataloader, just call the method.
|
||||||
"""
|
"""
|
||||||
if not self.solver.problem.collector.full:
|
if not self.solver.problem.are_all_domains_discretised:
|
||||||
error_message = '\n'.join([
|
error_message = '\n'.join([
|
||||||
f"""{" " * 13} ---> Condition {key} {"sampled" if value else
|
f"""{" " * 13} ---> Domain {key} {"sampled" if key in self.solver.problem.discretised_domains else
|
||||||
"not sampled"}""" for key, value in
|
"not sampled"}""" for key in
|
||||||
self._solver.problem.collector._is_conditions_ready.items()
|
self.solver.problem.domains.keys()
|
||||||
])
|
])
|
||||||
raise RuntimeError('Cannot create Trainer if not all conditions '
|
raise RuntimeError('Cannot create Trainer if not all conditions '
|
||||||
'are sampled. The Trainer got the following:\n'
|
'are sampled. The Trainer got the following:\n'
|
||||||
f'{error_message}')
|
f'{error_message}')
|
||||||
automatic_batching = False
|
automatic_batching = False
|
||||||
self.data_module = PinaDataModule(collector=self.solver.problem.collector,
|
self.data_module = PinaDataModule(self.solver.problem,
|
||||||
train_size=self.train_size,
|
train_size=self.train_size,
|
||||||
test_size=self.test_size,
|
test_size=self.test_size,
|
||||||
val_size=self.val_size,
|
val_size=self.val_size,
|
||||||
|
|||||||
Reference in New Issue
Block a user