minor changes/ trainer update

Dario Coscia
2024-10-10 19:24:46 +02:00
committed by Nicola Demo
parent 7528f6ef74
commit b9753c34b2
8 changed files with 69 additions and 46 deletions

View File

@@ -5,21 +5,35 @@ from .utils import check_consistency, merge_tensors
 class Collector:
     def __init__(self, problem):
-        self.problem = problem # hook Collector <-> Problem
-        self.data_collections = {name : {} for name in self.problem.conditions} # collection of data
-        self.is_conditions_ready = {
-            name : False for name in self.problem.conditions} # names of the conditions that need to be sampled
-        self.full = False # collector full, all points for all conditions are given and the data are ready to be used in trainig
+        # creating a hook between collector and problem
+        self.problem = problem
+
+        # this variable is used to store the data in the form:
+        # {'[condition_name]' :
+        #     {'input_points' : Tensor,
+        #      '[equation/output_points/conditional_variables]': Tensor}
+        # }
+        # those variables are used for the dataloading
+        self._data_collections = {name : {} for name in self.problem.conditions}
+
+        # variables used to check that all conditions are sampled
+        self._is_conditions_ready = {
+            name : False for name in self.problem.conditions}
+        self.full = False

     @property
     def full(self):
-        return all(self.is_conditions_ready.values())
+        return all(self._is_conditions_ready.values())

     @full.setter
     def full(self, value):
         check_consistency(value, bool)
         self._full = value

+    @property
+    def data_collections(self):
+        return self._data_collections
+
     @property
     def problem(self):
         return self._problem
@@ -33,13 +47,13 @@ class Collector:
         for condition_name, condition in self.problem.conditions.items():
             # if the condition is not ready and domain is not attribute
             # of condition, we get and store the data
-            if (not self.is_conditions_ready[condition_name]) and (not hasattr(condition, "domain")):
+            if (not self._is_conditions_ready[condition_name]) and (not hasattr(condition, "domain")):
                 # get data
                 keys = condition.__slots__
                 values = [getattr(condition, name) for name in keys]
                 self.data_collections[condition_name] = dict(zip(keys, values))
                 # condition now is ready
-                self.is_conditions_ready[condition_name] = True
+                self._is_conditions_ready[condition_name] = True

     def store_sample_domains(self, n, mode, variables, sample_locations):
         # loop over all locations
@@ -48,7 +62,7 @@ class Collector:
             condition = self.problem.conditions[loc]
             keys = ["input_points", "equation"]
             # if the condition is not ready, we get and store the data
-            if (not self.is_conditions_ready[loc]):
+            if (not self._is_conditions_ready[loc]):
                 # if it is the first time we sample
                 if not self.data_collections[loc]:
                     already_sampled = []
@@ -57,7 +71,7 @@ class Collector:
                     already_sampled = [self.data_collections[loc]['input_points']]
                 # if the condition is ready but we want to sample again
                 else:
-                    self.is_conditions_ready[loc] = False
+                    self._is_conditions_ready[loc] = False
                     already_sampled = []

             # get the samples
@@ -70,7 +84,7 @@ class Collector:
             ):
                 pts = pts.sort_labels()
             if sorted(pts.labels)==sorted(self.problem.input_variables):
-                self.is_conditions_ready[loc] = True
+                self._is_conditions_ready[loc] = True
                 values = [pts, condition.equation]
                 self.data_collections[loc] = dict(zip(keys, values))
             else:
@@ -84,6 +98,6 @@ class Collector:
         :raises RuntimeError: if at least one condition is not already sampled
         """
         for k,v in new_points_dict.items():
-            if not self.is_conditions_ready[k]:
+            if not self._is_conditions_ready[k]:
                 raise RuntimeError('Cannot add points on a non sampled condition')
             self.data_collections[k]['input_points'] = self.data_collections[k]['input_points'].vstack(v)
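For context, a minimal sketch of how the refactored Collector is read from the outside after this change. The Poisson problem and the discretise_domain call are taken from the test file at the end of this commit; everything else mirrors the new public properties above.

    # sketch only: Poisson is the toy problem defined in the test suite below
    problem = Poisson()
    collector = problem.collector

    # data_collections is now a read-only property backed by _data_collections
    print(collector.data_collections.keys())

    # full flips to True only once every condition has been sampled;
    # fixed-data conditions are marked ready by store_fixed_data() at init
    assert collector.full is False
    problem.discretise_domain(10)
    assert collector.full is True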

View File

@@ -13,20 +13,20 @@ class DataConditionInterface(ConditionInterface):
     distribution
     """

-    __slots__ = ["data", "conditionalvariable"]
+    __slots__ = ["input_points", "conditional_variables"]

-    def __init__(self, data, conditionalvariable=None):
+    def __init__(self, input_points, conditional_variables=None):
         """
         TODO
         """
         super().__init__()
-        self.data = data
-        self.conditionalvariable = conditionalvariable
+        self.input_points = input_points
+        self.conditional_variables = conditional_variables
         self.condition_type = 'unsupervised'

     def __setattr__(self, key, value):
-        if (key == 'data') or (key == 'conditionalvariable'):
+        if (key == 'input_points') or (key == 'conditional_variables'):
             check_consistency(value, (LabelTensor, Graph, torch.Tensor))
             DataConditionInterface.__dict__[key].__set__(self, value)
-        elif key in ('_condition_type', '_problem', 'problem', 'condition_type'):
+        elif key in ('problem', 'condition_type'):
             super().__setattr__(key, value)
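The rename above is user-visible: the unsupervised condition now uses the same input_points naming as the other conditions. A rough sketch of the new constructor call, with plain torch tensors standing in for real data (the tensor shapes are purely illustrative):

    import torch

    cond = DataConditionInterface(
        input_points=torch.rand(100, 2),           # was: data=...
        conditional_variables=torch.rand(100, 1),  # was: conditionalvariable=...
    )
    assert cond.condition_type == 'unsupervised'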

View File

@@ -29,5 +29,5 @@ class DomainEquationCondition(ConditionInterface):
         elif key == 'equation':
             check_consistency(value, (EquationInterface))
             DomainEquationCondition.__dict__[key].__set__(self, value)
-        elif key in ('_condition_type', '_problem', 'problem', 'condition_type'):
+        elif key in ('problem', 'condition_type'):
             super().__setattr__(key, value)

View File

@@ -30,5 +30,5 @@ class InputPointsEquationCondition(ConditionInterface):
         elif key == 'equation':
             check_consistency(value, (EquationInterface))
             InputPointsEquationCondition.__dict__[key].__set__(self, value)
-        elif key in ('_condition_type', '_problem', 'problem', 'condition_type'):
+        elif key in ('problem', 'condition_type'):
             super().__setattr__(key, value)

View File

@@ -27,5 +27,5 @@ class InputOutputPointsCondition(ConditionInterface):
         if (key == 'input_points') or (key == 'output_points'):
             check_consistency(value, (LabelTensor, Graph, torch.Tensor))
             InputOutputPointsCondition.__dict__[key].__set__(self, value)
-        elif key in ('_condition_type', '_problem', 'problem', 'condition_type'):
+        elif key in ('problem', 'condition_type'):
             super().__setattr__(key, value)

View File

@@ -20,7 +20,7 @@ class AbstractProblem(metaclass=ABCMeta):
     def __init__(self):
         # create collector to manage problem data
-        self.collector = Collector(self)
+        self._collector = Collector(self)

         # create hook conditions <-> problems
         for condition_name in self.conditions:
@@ -33,7 +33,12 @@ class AbstractProblem(metaclass=ABCMeta):
         # points are ready.
         self.collector.store_fixed_data()

+    @property
+    def collector(self):
+        return self._collector
+
+    # TODO this should be erase when dataloading will interface collector,
+    # kept only for back compatibility
     @property
     def input_pts(self):
         to_return = {}
@@ -42,10 +47,6 @@ class AbstractProblem(metaclass=ABCMeta):
             to_return[k] = v['input_points']
         return to_return

-    @property
-    def _have_sampled_points(self):
-        return self.collector.is_conditions_ready
-
     def __deepcopy__(self, memo):
         """
         Implements deepcopy for the
@@ -160,7 +161,9 @@ class AbstractProblem(metaclass=ABCMeta):
         # check correct location
         if locations == "all":
-            locations = [name for name in self.conditions.keys()]
+            locations = [name for name in self.conditions.keys()
+                         if isinstance(self.conditions[name],
+                                       DomainEquationCondition)]
         else:
             if not isinstance(locations, (list)):
                 locations = [locations]
@@ -168,7 +171,7 @@ class AbstractProblem(metaclass=ABCMeta):
             if not isinstance(self.conditions[loc], DomainEquationCondition):
                 raise TypeError(
                     f"Wrong locations passed, locations for sampling "
-                    f"should be in {[loc for loc in locations if not isinstance(self.conditions[loc], DomainEquationCondition)]}.",
+                    f"should be in {[loc for loc in locations if isinstance(self.conditions[loc], DomainEquationCondition)]}.",
                 )

         # store data
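The effect of the two hunks above, sketched against the Poisson test problem (assuming locations defaults to "all", which is what the new test call discretise_domain(n) relies on): locations="all" now expands only to domain-equation conditions, and the error message lists the locations that are valid rather than the ones that are not.

    poisson_problem = Poisson()

    # samples every DomainEquationCondition and simply skips data conditions
    # instead of raising
    poisson_problem.discretise_domain(10)

    # explicitly naming a data condition still raises, but the message now
    # reports which locations *are* samplable ('data' is the fixed-data
    # condition name used in the new test below)
    poisson_problem.discretise_domain(10, locations=['data'])   # -> TypeError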

View File

@@ -32,29 +32,18 @@ class Trainer(pytorch_lightning.Trainer):
         if batch_size is not None:
             check_consistency(batch_size, int)
-        self._model = solver
+        self.solver = solver
         self.batch_size = batch_size
         self._create_loader()
         self._move_to_device()

-        # create dataloader
-        # if solver.problem.have_sampled_points is False:
-        #     raise RuntimeError(
-        #         f"Input points in {solver.problem.not_sampled_points} "
-        #         "training are None. Please "
-        #         "sample points in your problem by calling "
-        #         "discretise_domain function before train "
-        #         "in the provided locations."
-        #     )
-        # self._create_or_update_loader()

     def _move_to_device(self):
         device = self._accelerator_connector._parallel_devices[0]

         # move parameters to device
-        pb = self._model.problem
+        pb = self.solver.problem
         if hasattr(pb, "unknown_parameters"):
             for key in pb.unknown_parameters:
                 pb.unknown_parameters[key] = torch.nn.Parameter(
@@ -67,14 +56,21 @@ class Trainer(pytorch_lightning.Trainer):
         during training, there is no need to define to touch the
         trainer dataloader, just call the method.
         """
+        if not self.solver.problem.collector.full:
+            error_message = '\n'.join(
+                [f'{" " * 13} ---> Condition {key} {"sampled" if value else "not sampled"}'
+                 for key, value in self.solver.problem.collector._is_conditions_ready.items()])
+            raise RuntimeError('Cannot create Trainer if not all conditions '
+                               'are sampled. The Trainer got the following:\n'
+                               f'{error_message}')
+
         devices = self._accelerator_connector._parallel_devices
         if len(devices) > 1:
             raise RuntimeError("Parallel training is not supported yet.")

         device = devices[0]
-        dataset_phys = SamplePointDataset(self._model.problem, device)
-        dataset_data = DataPointDataset(self._model.problem, device)
+        dataset_phys = SamplePointDataset(self.solver.problem, device)
+        dataset_data = DataPointDataset(self.solver.problem, device)
         self._loader = SamplePointLoader(
             dataset_phys, dataset_data, batch_size=self.batch_size, shuffle=True
         )
@@ -84,7 +80,7 @@ class Trainer(pytorch_lightning.Trainer):
         Train the solver method.
         """
         return super().fit(
-            self._model, train_dataloaders=self._loader, **kwargs
+            self.solver, train_dataloaders=self._loader, **kwargs
         )

     @property
@@ -92,4 +88,4 @@ class Trainer(pytorch_lightning.Trainer):
""" """
Returning trainer solver. Returning trainer solver.
""" """
return self._model return self._solver
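Taken together, the Trainer changes mean the sampling check now happens eagerly when the Trainer is built, since _create_loader runs inside __init__. A rough sketch of the intended flow, where Poisson is the test problem, PINN stands in for a solver, the stand-in network and the train() method name are assumptions not shown in this diff:

    import torch

    problem = Poisson()
    model = torch.nn.Linear(2, 1)                 # illustrative stand-in network
    solver = PINN(problem=problem, model=model)   # assumed solver constructor

    # with unsampled conditions the constructor now fails fast and prints one
    # "sampled" / "not sampled" line per condition
    try:
        Trainer(solver, batch_size=8)
    except RuntimeError as err:
        print(err)

    problem.discretise_domain(10)
    trainer = Trainer(solver, batch_size=8)       # ok once collector.full is True
    trainer.train()                               # wraps super().fit as shown above
    trainer.solver                                # solver exposed via the renamed property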

View File

@@ -89,6 +89,7 @@ def test_discretise_domain():
     poisson_problem.discretise_domain(n, 'lh', locations=['D'])
     assert poisson_problem.input_pts['D'].shape[0] == n
+    poisson_problem.discretise_domain(n)


 def test_sampling_few_variables():
     n = 10
@@ -98,7 +99,7 @@ def test_sampling_few_variables():
                                      locations=['D'],
                                      variables=['x'])
     assert poisson_problem.input_pts['D'].shape[1] == 1
-    assert poisson_problem._have_sampled_points['D'] is False
+    assert poisson_problem.collector._is_conditions_ready['D'] is False


 def test_variables_correct_order_sampling():
@@ -140,3 +141,12 @@ def test_add_points():
     poisson_problem.add_points({'D': new_pts})
     assert torch.isclose(poisson_problem.input_pts['D'].extract('x'), new_pts.extract('x'))
     assert torch.isclose(poisson_problem.input_pts['D'].extract('y'), new_pts.extract('y'))
+
+
+def test_collector():
+    poisson_problem = Poisson()
+    collector = poisson_problem.collector
+    assert collector.full is False
+    assert collector._is_conditions_ready['data'] is True
+    poisson_problem.discretise_domain(10)
+    assert collector.full is True