clean logic, fix problems for tutorial1

This commit is contained in:
Nicola Demo
2025-02-06 14:28:17 +01:00
parent 7702427e8d
commit effd1e83bb
5 changed files with 105 additions and 76 deletions

View File

@@ -62,46 +62,58 @@ class Collector:
# condition now is ready
self._is_conditions_ready[condition_name] = True
def store_sample_domains(self, n, mode, variables, sample_locations):
# loop over all locations
for loc in sample_locations:
# get condition
condition = self.problem.conditions[loc]
condition_domain = condition.domain
if isinstance(condition_domain, str):
condition_domain = self.problem.domains[condition_domain]
keys = ["input_points", "equation"]
# if the condition is not ready, we get and store the data
if not self._is_conditions_ready[loc]:
# if it is the first time we sample
if not self.data_collections[loc]:
already_sampled = []
# if we have sampled the condition but not all variables
else:
already_sampled = [
self.data_collections[loc]['input_points']
]
# if the condition is ready but we want to sample again
else:
self._is_conditions_ready[loc] = False
already_sampled = []
def store_sample_domains(self):
"""
Add
"""
for condition_name in self.problem.conditions:
condition = self.problem.conditions[condition_name]
if not hasattr(condition, "domain"):
continue
# get the samples
samples = [
condition_domain.sample(n=n, mode=mode,
variables=variables)
] + already_sampled
pts = merge_tensors(samples)
if set(pts.labels).issubset(sorted(self.problem.input_variables)):
pts = pts.sort_labels()
if sorted(pts.labels) == sorted(self.problem.input_variables):
self._is_conditions_ready[loc] = True
values = [pts, condition.equation]
self.data_collections[loc] = dict(zip(keys, values))
else:
raise RuntimeError(
'Try to sample variables which are not in problem defined '
'in the problem')
samples = self.problem.discretised_domains[condition.domain]
self.data_collections[condition_name] = {
'input_points': samples,
'equation': condition.equation
}
# # get condition
# condition = self.problem.conditions[loc]
# condition_domain = condition.domain
# if isinstance(condition_domain, str):
# condition_domain = self.problem.domains[condition_domain]
# keys = ["input_points", "equation"]
# # if the condition is not ready, we get and store the data
# if not self._is_conditions_ready[loc]:
# # if it is the first time we sample
# if not self.data_collections[loc]:
# already_sampled = []
# # if we have sampled the condition but not all variables
# else:
# already_sampled = [
# self.data_collections[loc]['input_points']
# ]
# # if the condition is ready but we want to sample again
# else:
# self._is_conditions_ready[loc] = False
# already_sampled = []
# # get the samples
# samples = [
# condition_domain.sample(n=n, mode=mode,
# variables=variables)
# ] + already_sampled
# pts = merge_tensors(samples)
# if set(pts.labels).issubset(sorted(self.problem.input_variables)):
# pts = pts.sort_labels()
# if sorted(pts.labels) == sorted(self.problem.input_variables):
# self._is_conditions_ready[loc] = True
# values = [pts, condition.equation]
# self.data_collections[loc] = dict(zip(keys, values))
# else:
# raise RuntimeError(
# 'Try to sample variables which are not in problem defined '
# 'in the problem')
def add_points(self, new_points_dict):
"""

View File

@@ -7,7 +7,7 @@ from torch.utils.data import DataLoader, BatchSampler, SequentialSampler, \
RandomSampler
from torch.utils.data.distributed import DistributedSampler
from .dataset import PinaDatasetFactory
from ..collector import Collector
class DummyDataloader:
def __init__(self, dataset, device):
@@ -87,7 +87,7 @@ class PinaDataModule(LightningDataModule):
"""
def __init__(self,
collector,
problem,
train_size=.7,
test_size=.2,
val_size=.1,
@@ -99,7 +99,6 @@ class PinaDataModule(LightningDataModule):
):
"""
Initialize the object, creating dataset based on input problem
:param Collector collector: PINA problem
:param train_size: number/percentage of elements in train split
:param test_size: number/percentage of elements in test split
:param val_size: number/percentage of elements in evaluation split
@@ -135,6 +134,10 @@ class PinaDataModule(LightningDataModule):
self.predict_dataset = None
else:
self.predict_dataloader = super().predict_dataloader
collector = Collector(problem)
collector.store_fixed_data()
collector.store_sample_domains()
self.collector_splits = self._create_splits(collector, splits_dict)
self.transfer_batch_to_device = self._transfer_batch_to_device

View File

@@ -58,6 +58,7 @@ class PinaTensorDataset(PinaDataset):
def __init__(self, conditions_dict, max_conditions_lengths,
automatic_batching):
super().__init__(conditions_dict, max_conditions_lengths)
if automatic_batching:
self._getitem_func = self._getitem_int
else:

View File

@@ -20,8 +20,8 @@ class AbstractProblem(metaclass=ABCMeta):
def __init__(self):
self.discretised_domains = {}
# create collector to manage problem data
self._collector = Collector(self)
# create hook conditions <-> problems
for condition_name in self.conditions:
@@ -32,12 +32,12 @@ class AbstractProblem(metaclass=ABCMeta):
# sampling locations). To check that all data points are ready for
# training all type self.collector.full, which returns true if all
# points are ready.
self.collector.store_fixed_data()
# self.collector.store_fixed_data()
self._batching_dimension = 0
@property
def collector(self):
return self._collector
# @property
# def collector(self):
# return self._collector
@property
def batching_dimension(self):
@@ -74,6 +74,21 @@ class AbstractProblem(metaclass=ABCMeta):
setattr(result, k, deepcopy(v, memo))
return result
@property
def are_all_domains_discretised(self):
"""
Check if all the domains are discretised.
:return: True if all the domains are discretised, False otherwise
:rtype: bool
"""
return all(
[
domain in self.discretised_domains
for domain in self.domains.keys()
]
)
@property
def input_variables(self):
"""
@@ -120,7 +135,7 @@ class AbstractProblem(metaclass=ABCMeta):
n,
mode="random",
variables="all",
locations="all"):
domains="all"):
"""
Generate a set of points to span the `Location` of all the conditions of
the problem.
@@ -134,12 +149,12 @@ class AbstractProblem(metaclass=ABCMeta):
chebyshev sampling, ``chebyshev``; grid sampling ``grid``.
:param variables: problem's variables to be sampled, defaults to 'all'.
:type variables: str | list[str]
:param locations: problem's locations from where to sample, defaults to 'all'.
:param domain: problem's domain from where to sample, defaults to 'all'.
:type locations: str
:Example:
>>> pinn.discretise_domain(n=10, mode='grid')
>>> pinn.discretise_domain(n=10, mode='grid', location=['bound1'])
>>> pinn.discretise_domain(n=10, mode='grid', domain=['bound1'])
>>> pinn.discretise_domain(n=10, mode='grid', variables=['x'])
.. warning::
@@ -154,11 +169,11 @@ class AbstractProblem(metaclass=ABCMeta):
check_consistency(n, int)
check_consistency(mode, str)
check_consistency(variables, str)
check_consistency(locations, str)
check_consistency(domains, (list, str))
# check correct sampling mode
if mode not in DomainInterface.available_sampling_modes:
raise TypeError(f"mode {mode} not valid.")
# if mode not in DomainInterface.available_sampling_modes:
# raise TypeError(f"mode {mode} not valid.")
# check correct variables
if variables == "all":
@@ -170,23 +185,21 @@ class AbstractProblem(metaclass=ABCMeta):
f"should be in {self.input_variables}.",
)
# check correct location
if locations == "all":
locations = [
name for name in self.conditions.keys()
if isinstance(self.conditions[name], DomainEquationCondition)
]
else:
if not isinstance(locations, (list)):
locations = [locations]
for loc in locations:
if not isinstance(self.conditions[loc], DomainEquationCondition):
raise TypeError(
f"Wrong locations passed, locations for sampling "
f"should be in {[loc for loc in locations if isinstance(self.conditions[loc], DomainEquationCondition)]}.",
)
if domains == "all":
domains = self.domains.keys()
elif not isinstance(domains, (list)):
domains = [domains]
for domain in domains:
self.discretised_domains[domain] = (
self.domains[domain].sample(n, mode, variables)
)
# if not isinstance(self.conditions[loc], DomainEquationCondition):
# raise TypeError(
# f"Wrong locations passed, locations for sampling "
# f"should be in {[loc for loc in locations if isinstance(self.conditions[loc], DomainEquationCondition)]}.",
# )
# store data
self.collector.store_sample_domains(n, mode, variables, locations)
def add_points(self, new_points_dict):
self.collector.add_points(new_points_dict)
# self.collector.store_sample_domains()
# self.collector.store_sample_domains(n, mode, variables, domain)

View File

@@ -63,17 +63,17 @@ class Trainer(lightning.pytorch.Trainer):
during training, there is no need to define to touch the
trainer dataloader, just call the method.
"""
if not self.solver.problem.collector.full:
if not self.solver.problem.are_all_domains_discretised:
error_message = '\n'.join([
f"""{" " * 13} ---> Condition {key} {"sampled" if value else
"not sampled"}""" for key, value in
self._solver.problem.collector._is_conditions_ready.items()
f"""{" " * 13} ---> Domain {key} {"sampled" if key in self.solver.problem.discretised_domains else
"not sampled"}""" for key in
self.solver.problem.domains.keys()
])
raise RuntimeError('Cannot create Trainer if not all conditions '
'are sampled. The Trainer got the following:\n'
f'{error_message}')
automatic_batching = False
self.data_module = PinaDataModule(collector=self.solver.problem.collector,
self.data_module = PinaDataModule(self.solver.problem,
train_size=self.train_size,
test_size=self.test_size,
val_size=self.val_size,