clean logic, fix problems for tutorial1

This commit is contained in:
Nicola Demo
2025-02-06 14:28:17 +01:00
parent 7702427e8d
commit effd1e83bb
5 changed files with 105 additions and 76 deletions

View File

@@ -62,46 +62,58 @@ class Collector:
# condition now is ready # condition now is ready
self._is_conditions_ready[condition_name] = True self._is_conditions_ready[condition_name] = True
def store_sample_domains(self, n, mode, variables, sample_locations): def store_sample_domains(self):
# loop over all locations """
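for loc in sample_locations: Store, for every condition that defines a domain, the discretised samples of that domain together with the condition's equation.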
# get condition """
condition = self.problem.conditions[loc] for condition_name in self.problem.conditions:
condition_domain = condition.domain condition = self.problem.conditions[condition_name]
if isinstance(condition_domain, str): if not hasattr(condition, "domain"):
condition_domain = self.problem.domains[condition_domain] continue
keys = ["input_points", "equation"]
# if the condition is not ready, we get and store the data
if not self._is_conditions_ready[loc]:
# if it is the first time we sample
if not self.data_collections[loc]:
already_sampled = []
# if we have sampled the condition but not all variables
else:
already_sampled = [
self.data_collections[loc]['input_points']
]
# if the condition is ready but we want to sample again
else:
self._is_conditions_ready[loc] = False
already_sampled = []
# get the samples samples = self.problem.discretised_domains[condition.domain]
samples = [
condition_domain.sample(n=n, mode=mode, self.data_collections[condition_name] = {
variables=variables) 'input_points': samples,
] + already_sampled 'equation': condition.equation
pts = merge_tensors(samples) }
if set(pts.labels).issubset(sorted(self.problem.input_variables)):
pts = pts.sort_labels() # # get condition
if sorted(pts.labels) == sorted(self.problem.input_variables): # condition = self.problem.conditions[loc]
self._is_conditions_ready[loc] = True # condition_domain = condition.domain
values = [pts, condition.equation] # if isinstance(condition_domain, str):
self.data_collections[loc] = dict(zip(keys, values)) # condition_domain = self.problem.domains[condition_domain]
else: # keys = ["input_points", "equation"]
raise RuntimeError( # # if the condition is not ready, we get and store the data
'Try to sample variables which are not in problem defined ' # if not self._is_conditions_ready[loc]:
'in the problem') # # if it is the first time we sample
# if not self.data_collections[loc]:
# already_sampled = []
# # if we have sampled the condition but not all variables
# else:
# already_sampled = [
# self.data_collections[loc]['input_points']
# ]
# # if the condition is ready but we want to sample again
# else:
# self._is_conditions_ready[loc] = False
# already_sampled = []
# # get the samples
# samples = [
# condition_domain.sample(n=n, mode=mode,
# variables=variables)
# ] + already_sampled
# pts = merge_tensors(samples)
# if set(pts.labels).issubset(sorted(self.problem.input_variables)):
# pts = pts.sort_labels()
# if sorted(pts.labels) == sorted(self.problem.input_variables):
# self._is_conditions_ready[loc] = True
# values = [pts, condition.equation]
# self.data_collections[loc] = dict(zip(keys, values))
# else:
# raise RuntimeError(
# 'Try to sample variables which are not in problem defined '
# 'in the problem')
def add_points(self, new_points_dict): def add_points(self, new_points_dict):
""" """

View File

@@ -7,7 +7,7 @@ from torch.utils.data import DataLoader, BatchSampler, SequentialSampler, \
RandomSampler RandomSampler
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from .dataset import PinaDatasetFactory from .dataset import PinaDatasetFactory
from ..collector import Collector
class DummyDataloader: class DummyDataloader:
def __init__(self, dataset, device): def __init__(self, dataset, device):
@@ -87,7 +87,7 @@ class PinaDataModule(LightningDataModule):
""" """
def __init__(self, def __init__(self,
collector, problem,
train_size=.7, train_size=.7,
test_size=.2, test_size=.2,
val_size=.1, val_size=.1,
@@ -99,7 +99,6 @@ class PinaDataModule(LightningDataModule):
): ):
""" """
Initialize the object, creating dataset based on input problem Initialize the object, creating dataset based on input problem
:param Collector collector: PINA problem
:param train_size: number/percentage of elements in train split :param train_size: number/percentage of elements in train split
:param test_size: number/percentage of elements in test split :param test_size: number/percentage of elements in test split
:param val_size: number/percentage of elements in evaluation split :param val_size: number/percentage of elements in evaluation split
@@ -135,6 +134,10 @@ class PinaDataModule(LightningDataModule):
self.predict_dataset = None self.predict_dataset = None
else: else:
self.predict_dataloader = super().predict_dataloader self.predict_dataloader = super().predict_dataloader
collector = Collector(problem)
collector.store_fixed_data()
collector.store_sample_domains()
self.collector_splits = self._create_splits(collector, splits_dict) self.collector_splits = self._create_splits(collector, splits_dict)
self.transfer_batch_to_device = self._transfer_batch_to_device self.transfer_batch_to_device = self._transfer_batch_to_device

View File

@@ -58,6 +58,7 @@ class PinaTensorDataset(PinaDataset):
def __init__(self, conditions_dict, max_conditions_lengths, def __init__(self, conditions_dict, max_conditions_lengths,
automatic_batching): automatic_batching):
super().__init__(conditions_dict, max_conditions_lengths) super().__init__(conditions_dict, max_conditions_lengths)
if automatic_batching: if automatic_batching:
self._getitem_func = self._getitem_int self._getitem_func = self._getitem_int
else: else:

View File

@@ -20,8 +20,8 @@ class AbstractProblem(metaclass=ABCMeta):
def __init__(self): def __init__(self):
self.discretised_domains = {}
# create collector to manage problem data # create collector to manage problem data
self._collector = Collector(self)
# create hook conditions <-> problems # create hook conditions <-> problems
for condition_name in self.conditions: for condition_name in self.conditions:
@@ -32,12 +32,12 @@ class AbstractProblem(metaclass=ABCMeta):
# sampling locations). To check that all data points are ready for # sampling locations). To check that all data points are ready for
# training all type self.collector.full, which returns true if all # training all type self.collector.full, which returns true if all
# points are ready. # points are ready.
self.collector.store_fixed_data() # self.collector.store_fixed_data()
self._batching_dimension = 0 self._batching_dimension = 0
@property # @property
def collector(self): # def collector(self):
return self._collector # return self._collector
@property @property
def batching_dimension(self): def batching_dimension(self):
@@ -74,6 +74,21 @@ class AbstractProblem(metaclass=ABCMeta):
setattr(result, k, deepcopy(v, memo)) setattr(result, k, deepcopy(v, memo))
return result return result
@property
def are_all_domains_discretised(self):
"""
Check if all the domains are discretised.
:return: True if all the domains are discretised, False otherwise
:rtype: bool
"""
return all(
[
domain in self.discretised_domains
for domain in self.domains.keys()
]
)
@property @property
def input_variables(self): def input_variables(self):
""" """
@@ -120,7 +135,7 @@ class AbstractProblem(metaclass=ABCMeta):
n, n,
mode="random", mode="random",
variables="all", variables="all",
locations="all"): domains="all"):
""" """
Generate a set of points to span the `Location` of all the conditions of Generate a set of points to span the `Location` of all the conditions of
the problem. the problem.
@@ -134,12 +149,12 @@ class AbstractProblem(metaclass=ABCMeta):
chebyshev sampling, ``chebyshev``; grid sampling ``grid``. chebyshev sampling, ``chebyshev``; grid sampling ``grid``.
:param variables: problem's variables to be sampled, defaults to 'all'. :param variables: problem's variables to be sampled, defaults to 'all'.
:type variables: str | list[str] :type variables: str | list[str]
:param locations: problem's locations from where to sample, defaults to 'all'. :param domains: problem's domains from where to sample, defaults to 'all'.
:type locations: str :type domains: str | list[str]
:Example: :Example:
>>> pinn.discretise_domain(n=10, mode='grid') >>> pinn.discretise_domain(n=10, mode='grid')
>>> pinn.discretise_domain(n=10, mode='grid', location=['bound1']) >>> pinn.discretise_domain(n=10, mode='grid', domains=['bound1'])
>>> pinn.discretise_domain(n=10, mode='grid', variables=['x']) >>> pinn.discretise_domain(n=10, mode='grid', variables=['x'])
.. warning:: .. warning::
@@ -154,11 +169,11 @@ class AbstractProblem(metaclass=ABCMeta):
check_consistency(n, int) check_consistency(n, int)
check_consistency(mode, str) check_consistency(mode, str)
check_consistency(variables, str) check_consistency(variables, str)
check_consistency(locations, str) check_consistency(domains, (list, str))
# check correct sampling mode # check correct sampling mode
if mode not in DomainInterface.available_sampling_modes: # if mode not in DomainInterface.available_sampling_modes:
raise TypeError(f"mode {mode} not valid.") # raise TypeError(f"mode {mode} not valid.")
# check correct variables # check correct variables
if variables == "all": if variables == "all":
@@ -170,23 +185,21 @@ class AbstractProblem(metaclass=ABCMeta):
f"should be in {self.input_variables}.", f"should be in {self.input_variables}.",
) )
# check correct location # check correct domains
if locations == "all": if domains == "all":
locations = [ domains = self.domains.keys()
name for name in self.conditions.keys() elif not isinstance(domains, (list)):
if isinstance(self.conditions[name], DomainEquationCondition) domains = [domains]
]
else: for domain in domains:
if not isinstance(locations, (list)): self.discretised_domains[domain] = (
locations = [locations] self.domains[domain].sample(n, mode, variables)
for loc in locations: )
if not isinstance(self.conditions[loc], DomainEquationCondition): # if not isinstance(self.conditions[loc], DomainEquationCondition):
raise TypeError( # raise TypeError(
f"Wrong locations passed, locations for sampling " # f"Wrong locations passed, locations for sampling "
f"should be in {[loc for loc in locations if isinstance(self.conditions[loc], DomainEquationCondition)]}.", # f"should be in {[loc for loc in locations if isinstance(self.conditions[loc], DomainEquationCondition)]}.",
) # )
# store data # store data
self.collector.store_sample_domains(n, mode, variables, locations) # self.collector.store_sample_domains()
# self.collector.store_sample_domains(n, mode, variables, domain)
def add_points(self, new_points_dict):
self.collector.add_points(new_points_dict)

View File

@@ -63,17 +63,17 @@ class Trainer(lightning.pytorch.Trainer):
during training, there is no need to define or touch the during training, there is no need to define or touch the
trainer dataloader, just call the method. trainer dataloader, just call the method.
""" """
if not self.solver.problem.collector.full: if not self.solver.problem.are_all_domains_discretised:
error_message = '\n'.join([ error_message = '\n'.join([
f"""{" " * 13} ---> Condition {key} {"sampled" if value else f"""{" " * 13} ---> Domain {key} {"sampled" if key in self.solver.problem.discretised_domains else
"not sampled"}""" for key, value in "not sampled"}""" for key in
self._solver.problem.collector._is_conditions_ready.items() self.solver.problem.domains.keys()
]) ])
raise RuntimeError('Cannot create Trainer if not all conditions ' raise RuntimeError('Cannot create Trainer if not all conditions '
'are sampled. The Trainer got the following:\n' 'are sampled. The Trainer got the following:\n'
f'{error_message}') f'{error_message}')
automatic_batching = False automatic_batching = False
self.data_module = PinaDataModule(collector=self.solver.problem.collector, self.data_module = PinaDataModule(self.solver.problem,
train_size=self.train_size, train_size=self.train_size,
test_size=self.test_size, test_size=self.test_size,
val_size=self.val_size, val_size=self.val_size,