From b9753c34b25288fb679ea07a9cc18ef538c09e36 Mon Sep 17 00:00:00 2001 From: Dario Coscia Date: Thu, 10 Oct 2024 19:24:46 +0200 Subject: [PATCH] minor changes/ trainer update --- pina/collector.py | 38 ++++++++++++++------- pina/condition/data_condition.py | 12 +++---- pina/condition/domain_equation_condition.py | 2 +- pina/condition/input_equation_condition.py | 2 +- pina/condition/input_output_condition.py | 2 +- pina/problem/abstract_problem.py | 17 +++++---- pina/trainer.py | 30 +++++++--------- tests/test_problem.py | 12 ++++++- 8 files changed, 69 insertions(+), 46 deletions(-) diff --git a/pina/collector.py b/pina/collector.py index 0f4e9da..f44c222 100644 --- a/pina/collector.py +++ b/pina/collector.py @@ -5,21 +5,35 @@ from .utils import check_consistency, merge_tensors class Collector: def __init__(self, problem): - self.problem = problem # hook Collector <-> Problem - self.data_collections = {name : {} for name in self.problem.conditions} # collection of data - self.is_conditions_ready = { - name : False for name in self.problem.conditions} # names of the conditions that need to be sampled - self.full = False # collector full, all points for all conditions are given and the data are ready to be used in trainig + # creating a hook between collector and problem + self.problem = problem + + # this variable is used to store the data in the form: + # {'[condition_name]' : + # {'input_points' : Tensor, + # '[equation/output_points/conditional_variables]': Tensor} + # } + # those variables are used for the dataloading + self._data_collections = {name : {} for name in self.problem.conditions} + + # variables used to check that all conditions are sampled + self._is_conditions_ready = { + name : False for name in self.problem.conditions} + self.full = False @property def full(self): - return all(self.is_conditions_ready.values()) + return all(self._is_conditions_ready.values()) @full.setter def full(self, value): check_consistency(value, bool) self._full = value + @property + def data_collections(self): + return self._data_collections + @property def problem(self): return self._problem @@ -33,13 +47,13 @@ class Collector: for condition_name, condition in self.problem.conditions.items(): # if the condition is not ready and domain is not attribute # of condition, we get and store the data - if (not self.is_conditions_ready[condition_name]) and (not hasattr(condition, "domain")): + if (not self._is_conditions_ready[condition_name]) and (not hasattr(condition, "domain")): # get data keys = condition.__slots__ values = [getattr(condition, name) for name in keys] self.data_collections[condition_name] = dict(zip(keys, values)) # condition now is ready - self.is_conditions_ready[condition_name] = True + self._is_conditions_ready[condition_name] = True def store_sample_domains(self, n, mode, variables, sample_locations): # loop over all locations @@ -48,7 +62,7 @@ class Collector: condition = self.problem.conditions[loc] keys = ["input_points", "equation"] # if the condition is not ready, we get and store the data - if (not self.is_conditions_ready[loc]): + if (not self._is_conditions_ready[loc]): # if it is the first time we sample if not self.data_collections[loc]: already_sampled = [] @@ -57,7 +71,7 @@ class Collector: already_sampled = [self.data_collections[loc]['input_points']] # if the condition is ready but we want to sample again else: - self.is_conditions_ready[loc] = False + self._is_conditions_ready[loc] = False already_sampled = [] # get the samples @@ -70,7 +84,7 @@ class Collector: ): pts = pts.sort_labels() if sorted(pts.labels)==sorted(self.problem.input_variables): - self.is_conditions_ready[loc] = True + self._is_conditions_ready[loc] = True values = [pts, condition.equation] self.data_collections[loc] = dict(zip(keys, values)) else: @@ -84,6 +98,6 @@ class Collector: :raises RuntimeError: if at least one condition is not already sampled """ for k,v in new_points_dict.items(): - if not self.is_conditions_ready[k]: + if not self._is_conditions_ready[k]: raise RuntimeError('Cannot add points on a non sampled condition') self.data_collections[k]['input_points'] = self.data_collections[k]['input_points'].vstack(v) \ No newline at end of file diff --git a/pina/condition/data_condition.py b/pina/condition/data_condition.py index d5ac639..90d248b 100644 --- a/pina/condition/data_condition.py +++ b/pina/condition/data_condition.py @@ -13,20 +13,20 @@ class DataConditionInterface(ConditionInterface): distribution """ - __slots__ = ["data", "conditionalvariable"] + __slots__ = ["input_points", "conditional_variables"] - def __init__(self, data, conditionalvariable=None): + def __init__(self, input_points, conditional_variables=None): """ TODO """ super().__init__() - self.data = data - self.conditionalvariable = conditionalvariable + self.input_points = input_points + self.conditional_variables = conditional_variables self.condition_type = 'unsupervised' def __setattr__(self, key, value): - if (key == 'data') or (key == 'conditionalvariable'): + if (key == 'input_points') or (key == 'conditional_variables'): check_consistency(value, (LabelTensor, Graph, torch.Tensor)) DataConditionInterface.__dict__[key].__set__(self, value) - elif key in ('_condition_type', '_problem', 'problem', 'condition_type'): + elif key in ('problem', 'condition_type'): super().__setattr__(key, value) \ No newline at end of file diff --git a/pina/condition/domain_equation_condition.py b/pina/condition/domain_equation_condition.py index ab35d20..ce4c7d3 100644 --- a/pina/condition/domain_equation_condition.py +++ b/pina/condition/domain_equation_condition.py @@ -29,5 +29,5 @@ class DomainEquationCondition(ConditionInterface): elif key == 'equation': check_consistency(value, (EquationInterface)) DomainEquationCondition.__dict__[key].__set__(self, value) - elif key in ('_condition_type', '_problem', 'problem', 'condition_type'): + elif key in ('problem', 'condition_type'): super().__setattr__(key, value) \ No newline at end of file diff --git a/pina/condition/input_equation_condition.py b/pina/condition/input_equation_condition.py index dc12d02..ac47fa2 100644 --- a/pina/condition/input_equation_condition.py +++ b/pina/condition/input_equation_condition.py @@ -30,5 +30,5 @@ class InputPointsEquationCondition(ConditionInterface): elif key == 'equation': check_consistency(value, (EquationInterface)) InputPointsEquationCondition.__dict__[key].__set__(self, value) - elif key in ('_condition_type', '_problem', 'problem', 'condition_type'): + elif key in ('problem', 'condition_type'): super().__setattr__(key, value) \ No newline at end of file diff --git a/pina/condition/input_output_condition.py b/pina/condition/input_output_condition.py index a4fa489..f8fd46e 100644 --- a/pina/condition/input_output_condition.py +++ b/pina/condition/input_output_condition.py @@ -27,5 +27,5 @@ class InputOutputPointsCondition(ConditionInterface): if (key == 'input_points') or (key == 'output_points'): check_consistency(value, (LabelTensor, Graph, torch.Tensor)) InputOutputPointsCondition.__dict__[key].__set__(self, value) - elif key in ('_condition_type', '_problem', 'problem', 'condition_type'): + elif key in ('problem', 'condition_type'): super().__setattr__(key, value) diff --git a/pina/problem/abstract_problem.py b/pina/problem/abstract_problem.py index edf214a..600a688 100644 --- a/pina/problem/abstract_problem.py +++ b/pina/problem/abstract_problem.py @@ -20,7 +20,7 @@ class AbstractProblem(metaclass=ABCMeta): def __init__(self): # create collector to manage problem data - self.collector = Collector(self) + self._collector = Collector(self) # create hook conditions <-> problems for condition_name in self.conditions: @@ -33,7 +33,12 @@ class AbstractProblem(metaclass=ABCMeta): # points are ready. self.collector.store_fixed_data() + @property + def collector(self): + return self._collector + # TODO this should be erase when dataloading will interface collector, + # kept only for back compatibility @property def input_pts(self): to_return = {} @@ -41,10 +46,6 @@ class AbstractProblem(metaclass=ABCMeta): if 'input_points' in v.keys(): to_return[k] = v['input_points'] return to_return - - @property - def _have_sampled_points(self): - return self.collector.is_conditions_ready def __deepcopy__(self, memo): """ @@ -160,7 +161,9 @@ class AbstractProblem(metaclass=ABCMeta): # check correct location if locations == "all": - locations = [name for name in self.conditions.keys()] + locations = [name for name in self.conditions.keys() + if isinstance(self.conditions[name], + DomainEquationCondition)] else: if not isinstance(locations, (list)): locations = [locations] @@ -168,7 +171,7 @@ class AbstractProblem(metaclass=ABCMeta): if not isinstance(self.conditions[loc], DomainEquationCondition): raise TypeError( f"Wrong locations passed, locations for sampling " - f"should be in {[loc for loc in locations if not isinstance(self.conditions[loc], DomainEquationCondition)]}.", + f"should be in {[loc for loc in locations if isinstance(self.conditions[loc], DomainEquationCondition)]}.", ) # store data diff --git a/pina/trainer.py b/pina/trainer.py index 758bbaa..ba18f33 100644 --- a/pina/trainer.py +++ b/pina/trainer.py @@ -32,29 +32,18 @@ class Trainer(pytorch_lightning.Trainer): if batch_size is not None: check_consistency(batch_size, int) - self._model = solver + self.solver = solver self.batch_size = batch_size self._create_loader() self._move_to_device() - # create dataloader - # if solver.problem.have_sampled_points is False: - # raise RuntimeError( - # f"Input points in {solver.problem.not_sampled_points} " - # "training are None. Please " - # "sample points in your problem by calling " - # "discretise_domain function before train " - # "in the provided locations." - # ) - - # self._create_or_update_loader() def _move_to_device(self): device = self._accelerator_connector._parallel_devices[0] # move parameters to device - pb = self._model.problem + pb = self.solver.problem if hasattr(pb, "unknown_parameters"): for key in pb.unknown_parameters: pb.unknown_parameters[key] = torch.nn.Parameter( @@ -67,14 +56,21 @@ class Trainer(pytorch_lightning.Trainer): during training, there is no need to define to touch the trainer dataloader, just call the method. """ + if not self.solver.problem.collector.full: + error_message = '\n'.join( + [f'{" " * 13} ---> Condition {key} {"sampled" if value else "not sampled"}' + for key, value in self.solver.problem.collector._is_conditions_ready.items()]) + raise RuntimeError('Cannot create Trainer if not all conditions ' + 'are sampled. The Trainer got the following:\n' + f'{error_message}') devices = self._accelerator_connector._parallel_devices if len(devices) > 1: raise RuntimeError("Parallel training is not supported yet.") device = devices[0] - dataset_phys = SamplePointDataset(self._model.problem, device) - dataset_data = DataPointDataset(self._model.problem, device) + dataset_phys = SamplePointDataset(self.solver.problem, device) + dataset_data = DataPointDataset(self.solver.problem, device) self._loader = SamplePointLoader( dataset_phys, dataset_data, batch_size=self.batch_size, shuffle=True ) @@ -84,7 +80,7 @@ class Trainer(pytorch_lightning.Trainer): Train the solver method. """ return super().fit( - self._model, train_dataloaders=self._loader, **kwargs + self.solver, train_dataloaders=self._loader, **kwargs ) @property @@ -92,4 +88,4 @@ class Trainer(pytorch_lightning.Trainer): """ Returning trainer solver. """ - return self._model + return self._solver diff --git a/tests/test_problem.py b/tests/test_problem.py index cc7e255..6d37596 100644 --- a/tests/test_problem.py +++ b/tests/test_problem.py @@ -89,6 +89,7 @@ def test_discretise_domain(): poisson_problem.discretise_domain(n, 'lh', locations=['D']) assert poisson_problem.input_pts['D'].shape[0] == n + poisson_problem.discretise_domain(n) def test_sampling_few_variables(): n = 10 @@ -98,7 +99,7 @@ def test_sampling_few_variables(): locations=['D'], variables=['x']) assert poisson_problem.input_pts['D'].shape[1] == 1 - assert poisson_problem._have_sampled_points['D'] is False + assert poisson_problem.collector._is_conditions_ready['D'] is False def test_variables_correct_order_sampling(): @@ -140,3 +141,12 @@ def test_add_points(): poisson_problem.add_points({'D': new_pts}) assert torch.isclose(poisson_problem.input_pts['D'].extract('x'), new_pts.extract('x')) assert torch.isclose(poisson_problem.input_pts['D'].extract('y'), new_pts.extract('y')) + + +def test_collector(): + poisson_problem = Poisson() + collector = poisson_problem.collector + assert collector.full is False + assert collector._is_conditions_ready['data'] is True + poisson_problem.discretise_domain(10) + assert collector.full is True