From effd1e83bb487a7e26989b43f2a7aac0a6ec710f Mon Sep 17 00:00:00 2001
From: Nicola Demo
Date: Thu, 6 Feb 2025 14:28:17 +0100
Subject: [PATCH] clean logic, fix problems for tutorial1

---
 pina/collector.py                | 90 ++++++++++++++++++--------------
 pina/data/data_module.py         |  9 ++--
 pina/data/dataset.py             |  1 +
 pina/problem/abstract_problem.py | 71 +++++++++++++++----------
 pina/trainer.py                  | 10 ++--
 5 files changed, 105 insertions(+), 76 deletions(-)

diff --git a/pina/collector.py b/pina/collector.py
index 1f0fb41..7e73525 100644
--- a/pina/collector.py
+++ b/pina/collector.py
@@ -62,46 +62,58 @@ class Collector:
         # condition now is ready
         self._is_conditions_ready[condition_name] = True

-    def store_sample_domains(self, n, mode, variables, sample_locations):
-        # loop over all locations
-        for loc in sample_locations:
-            # get condition
-            condition = self.problem.conditions[loc]
-            condition_domain = condition.domain
-            if isinstance(condition_domain, str):
-                condition_domain = self.problem.domains[condition_domain]
-            keys = ["input_points", "equation"]
-            # if the condition is not ready, we get and store the data
-            if not self._is_conditions_ready[loc]:
-                # if it is the first time we sample
-                if not self.data_collections[loc]:
-                    already_sampled = []
-                # if we have sampled the condition but not all variables
-                else:
-                    already_sampled = [
-                        self.data_collections[loc]['input_points']
-                    ]
-            # if the condition is ready but we want to sample again
-            else:
-                self._is_conditions_ready[loc] = False
-                already_sampled = []
+    def store_sample_domains(self):
+        """
+        Store the discretised domains as input points of each condition.
+        """
+        for condition_name in self.problem.conditions:
+            condition = self.problem.conditions[condition_name]
+            if not hasattr(condition, "domain"):
+                continue

-            # get the samples
-            samples = [
-                condition_domain.sample(n=n, mode=mode,
-                                        variables=variables)
-            ] + already_sampled
-            pts = merge_tensors(samples)
-            if set(pts.labels).issubset(sorted(self.problem.input_variables)):
-                pts = pts.sort_labels()
-            if sorted(pts.labels) == sorted(self.problem.input_variables):
-                self._is_conditions_ready[loc] = True
-                values = [pts, condition.equation]
-                self.data_collections[loc] = dict(zip(keys, values))
-            else:
-                raise RuntimeError(
-                    'Try to sample variables which are not in problem defined '
-                    'in the problem')
+            samples = self.problem.discretised_domains[condition.domain]
+
+            self.data_collections[condition_name] = {
+                'input_points': samples,
+                'equation': condition.equation
+            }
+
+            # # get condition
+            # condition = self.problem.conditions[loc]
+            # condition_domain = condition.domain
+            # if isinstance(condition_domain, str):
+            #     condition_domain = self.problem.domains[condition_domain]
+            # keys = ["input_points", "equation"]
+            # # if the condition is not ready, we get and store the data
+            # if not self._is_conditions_ready[loc]:
+            #     # if it is the first time we sample
+            #     if not self.data_collections[loc]:
+            #         already_sampled = []
+            #     # if we have sampled the condition but not all variables
+            #     else:
+            #         already_sampled = [
+            #             self.data_collections[loc]['input_points']
+            #         ]
+            # # if the condition is ready but we want to sample again
+            # else:
+            #     self._is_conditions_ready[loc] = False
+            #     already_sampled = []
+            # # get the samples
+            # samples = [
+            #     condition_domain.sample(n=n, mode=mode,
+            #                             variables=variables)
+            # ] + already_sampled
+            # pts = merge_tensors(samples)
+            # if set(pts.labels).issubset(sorted(self.problem.input_variables)):
+            #     pts = pts.sort_labels()
+            # if sorted(pts.labels) == sorted(self.problem.input_variables):
+            #     self._is_conditions_ready[loc] = True
+            #     values = [pts, condition.equation]
+            #     self.data_collections[loc] = dict(zip(keys, values))
+            # else:
+            #     raise RuntimeError(
+            #         'Try to sample variables which are not in problem defined '
+            #         'in the problem')

     def add_points(self, new_points_dict):
         """
diff --git a/pina/data/data_module.py b/pina/data/data_module.py
index 56473f8..3de2b11 100644
--- a/pina/data/data_module.py
+++ b/pina/data/data_module.py
@@ -7,7 +7,7 @@ from torch.utils.data import DataLoader, BatchSampler, SequentialSampler, \
     RandomSampler
 from torch.utils.data.distributed import DistributedSampler
 from .dataset import PinaDatasetFactory
-
+from ..collector import Collector

 class DummyDataloader:

     def __init__(self, dataset, device):
@@ -87,7 +87,7 @@ class PinaDataModule(LightningDataModule):
     """

     def __init__(self,
-                 collector,
+                 problem,
                  train_size=.7,
                  test_size=.2,
                  val_size=.1,
@@ -99,7 +99,6 @@
                 ):
         """
         Initialize the object, creating dataset based on input problem
-        :param Collector collector: PINA problem
         :param train_size: number/percentage of elements in train split
         :param test_size: number/percentage of elements in test split
         :param val_size: number/percentage of elements in evaluation split
@@ -135,6 +134,10 @@
             self.predict_dataset = None
         else:
             self.predict_dataloader = super().predict_dataloader
+
+        collector = Collector(problem)
+        collector.store_fixed_data()
+        collector.store_sample_domains()
         self.collector_splits = self._create_splits(collector, splits_dict)
         self.transfer_batch_to_device = self._transfer_batch_to_device
diff --git a/pina/data/dataset.py b/pina/data/dataset.py
index 4eeb20e..879bb96 100644
--- a/pina/data/dataset.py
+++ b/pina/data/dataset.py
@@ -58,6 +58,7 @@ class PinaTensorDataset(PinaDataset):
     def __init__(self, conditions_dict, max_conditions_lengths,
                  automatic_batching):
         super().__init__(conditions_dict, max_conditions_lengths)
+
         if automatic_batching:
             self._getitem_func = self._getitem_int
         else:
diff --git a/pina/problem/abstract_problem.py b/pina/problem/abstract_problem.py
index 5f424cf..164e269 100644
--- a/pina/problem/abstract_problem.py
+++ b/pina/problem/abstract_problem.py
@@ -20,8 +20,8 @@ class AbstractProblem(metaclass=ABCMeta):

     def __init__(self):
+        self.discretised_domains = {}
         # create collector to manage problem data
-        self._collector = Collector(self)

         # create hook conditions <-> problems
         for condition_name in self.conditions:
@@ -32,12 +32,12 @@
         # sampling locations). To check that all data points are ready for
         # training all type self.collector.full, which returns true if all
         # points are ready.
-        self.collector.store_fixed_data()
+        # self.collector.store_fixed_data()

         self._batching_dimension = 0

-    @property
-    def collector(self):
-        return self._collector
+    # @property
+    # def collector(self):
+    #     return self._collector

     @property
     def batching_dimension(self):
@@ -74,6 +74,21 @@
             setattr(result, k, deepcopy(v, memo))
         return result

+    @property
+    def are_all_domains_discretised(self):
+        """
+        Check if all the domains are discretised.
+
+        :return: True if all the domains are discretised, False otherwise
+        :rtype: bool
+        """
+        return all(
+            [
+                domain in self.discretised_domains
+                for domain in self.domains.keys()
+            ]
+        )
+
     @property
     def input_variables(self):
         """
@@ -120,7 +135,7 @@
                           n,
                           mode="random",
                           variables="all",
-                          locations="all"):
+                          domains="all"):
         """
         Generate a set of points to span the `Location` of all the
         conditions of the problem.
@@ -134,12 +149,12 @@
             chebyshev sampling, ``chebyshev``; grid sampling ``grid``.
         :param variables: problem's variables to be sampled, defaults to 'all'.
         :type variables: str | list[str]
-        :param locations: problem's locations from where to sample, defaults to 'all'.
-        :type locations: str
+        :param domains: problem's domains from where to sample, defaults to 'all'.
+        :type domains: str | list[str]
         :Example:
             >>> pinn.discretise_domain(n=10, mode='grid')
-            >>> pinn.discretise_domain(n=10, mode='grid', location=['bound1'])
+            >>> pinn.discretise_domain(n=10, mode='grid', domains=['bound1'])
             >>> pinn.discretise_domain(n=10, mode='grid', variables=['x'])

         .. warning::
@@ -154,11 +169,11 @@
         check_consistency(n, int)
         check_consistency(mode, str)
         check_consistency(variables, str)
-        check_consistency(locations, str)
+        check_consistency(domains, (list, str))

         # check correct sampling mode
-        if mode not in DomainInterface.available_sampling_modes:
-            raise TypeError(f"mode {mode} not valid.")
+        # if mode not in DomainInterface.available_sampling_modes:
+        #     raise TypeError(f"mode {mode} not valid.")

         # check correct variables
         if variables == "all":
@@ -170,23 +185,21 @@
                 f"should be in {self.input_variables}.",
             )
         # check correct location
-        if locations == "all":
-            locations = [
-                name for name in self.conditions.keys()
-                if isinstance(self.conditions[name], DomainEquationCondition)
-            ]
-        else:
-            if not isinstance(locations, (list)):
-                locations = [locations]
-            for loc in locations:
-                if not isinstance(self.conditions[loc], DomainEquationCondition):
-                    raise TypeError(
-                        f"Wrong locations passed, locations for sampling "
-                        f"should be in {[loc for loc in locations if isinstance(self.conditions[loc], DomainEquationCondition)]}.",
-                    )
+        if domains == "all":
+            domains = self.domains.keys()
+        elif not isinstance(domains, list):
+            domains = [domains]
+
+        for domain in domains:
+            self.discretised_domains[domain] = (
+                self.domains[domain].sample(n, mode, variables)
+            )
+            # if not isinstance(self.conditions[loc], DomainEquationCondition):
+            #     raise TypeError(
+            #         f"Wrong locations passed, locations for sampling "
+            #         f"should be in {[loc for loc in locations if isinstance(self.conditions[loc], DomainEquationCondition)]}.",
+            #     )
         # store data
-        self.collector.store_sample_domains(n, mode, variables, locations)
-
-    def add_points(self, new_points_dict):
-        self.collector.add_points(new_points_dict)
+        # self.collector.store_sample_domains()
+        # self.collector.store_sample_domains(n, mode, variables, domain)
\ No newline at end of file
diff --git a/pina/trainer.py b/pina/trainer.py
index 6a16248..759c1b3 100644
--- a/pina/trainer.py
+++ b/pina/trainer.py
@@ -63,17 +63,17 @@ class Trainer(lightning.pytorch.Trainer):
         during training, there is no need to define to touch the trainer
         dataloader, just call the method.
""" - if not self.solver.problem.collector.full: + if not self.solver.problem.are_all_domains_discretised: error_message = '\n'.join([ - f"""{" " * 13} ---> Condition {key} {"sampled" if value else - "not sampled"}""" for key, value in - self._solver.problem.collector._is_conditions_ready.items() + f"""{" " * 13} ---> Domain {key} {"sampled" if key in self.solver.problem.discretised_domains else + "not sampled"}""" for key in + self.solver.problem.domains.keys() ]) raise RuntimeError('Cannot create Trainer if not all conditions ' 'are sampled. The Trainer got the following:\n' f'{error_message}') automatic_batching = False - self.data_module = PinaDataModule(collector=self.solver.problem.collector, + self.data_module = PinaDataModule(self.solver.problem, train_size=self.train_size, test_size=self.test_size, val_size=self.val_size,