minor changes/ trainer update

This commit is contained in:
Dario Coscia
2024-10-10 19:24:46 +02:00
committed by Nicola Demo
parent 7528f6ef74
commit b9753c34b2
8 changed files with 69 additions and 46 deletions

View File

@@ -5,21 +5,35 @@ from .utils import check_consistency, merge_tensors
class Collector:
def __init__(self, problem):
self.problem = problem # hook Collector <-> Problem
self.data_collections = {name : {} for name in self.problem.conditions} # collection of data
self.is_conditions_ready = {
name : False for name in self.problem.conditions} # names of the conditions that need to be sampled
self.full = False # collector full, all points for all conditions are given and the data are ready to be used in training
# creating a hook between collector and problem
self.problem = problem
# this variable is used to store the data in the form:
# {'[condition_name]' :
# {'input_points' : Tensor,
# '[equation/output_points/conditional_variables]': Tensor}
# }
# those variables are used for the dataloading
self._data_collections = {name : {} for name in self.problem.conditions}
# variables used to check that all conditions are sampled
self._is_conditions_ready = {
name : False for name in self.problem.conditions}
self.full = False
@property
def full(self):
return all(self.is_conditions_ready.values())
return all(self._is_conditions_ready.values())
@full.setter
def full(self, value):
check_consistency(value, bool)
self._full = value
@property
def data_collections(self):
return self._data_collections
@property
def problem(self):
return self._problem
@@ -33,13 +47,13 @@ class Collector:
for condition_name, condition in self.problem.conditions.items():
# if the condition is not ready and domain is not an attribute
# of the condition, we get and store the data
if (not self.is_conditions_ready[condition_name]) and (not hasattr(condition, "domain")):
if (not self._is_conditions_ready[condition_name]) and (not hasattr(condition, "domain")):
# get data
keys = condition.__slots__
values = [getattr(condition, name) for name in keys]
self.data_collections[condition_name] = dict(zip(keys, values))
# condition now is ready
self.is_conditions_ready[condition_name] = True
self._is_conditions_ready[condition_name] = True
def store_sample_domains(self, n, mode, variables, sample_locations):
# loop over all locations
@@ -48,7 +62,7 @@ class Collector:
condition = self.problem.conditions[loc]
keys = ["input_points", "equation"]
# if the condition is not ready, we get and store the data
if (not self.is_conditions_ready[loc]):
if (not self._is_conditions_ready[loc]):
# if it is the first time we sample
if not self.data_collections[loc]:
already_sampled = []
@@ -57,7 +71,7 @@ class Collector:
already_sampled = [self.data_collections[loc]['input_points']]
# if the condition is ready but we want to sample again
else:
self.is_conditions_ready[loc] = False
self._is_conditions_ready[loc] = False
already_sampled = []
# get the samples
@@ -70,7 +84,7 @@ class Collector:
):
pts = pts.sort_labels()
if sorted(pts.labels)==sorted(self.problem.input_variables):
self.is_conditions_ready[loc] = True
self._is_conditions_ready[loc] = True
values = [pts, condition.equation]
self.data_collections[loc] = dict(zip(keys, values))
else:
@@ -84,6 +98,6 @@ class Collector:
:raises RuntimeError: if at least one condition is not already sampled
"""
for k,v in new_points_dict.items():
if not self.is_conditions_ready[k]:
if not self._is_conditions_ready[k]:
raise RuntimeError('Cannot add points on a non sampled condition')
self.data_collections[k]['input_points'] = self.data_collections[k]['input_points'].vstack(v)

View File

@@ -13,20 +13,20 @@ class DataConditionInterface(ConditionInterface):
distribution
"""
__slots__ = ["data", "conditionalvariable"]
__slots__ = ["input_points", "conditional_variables"]
def __init__(self, data, conditionalvariable=None):
def __init__(self, input_points, conditional_variables=None):
"""
TODO
"""
super().__init__()
self.data = data
self.conditionalvariable = conditionalvariable
self.input_points = input_points
self.conditional_variables = conditional_variables
self.condition_type = 'unsupervised'
def __setattr__(self, key, value):
if (key == 'data') or (key == 'conditionalvariable'):
if (key == 'input_points') or (key == 'conditional_variables'):
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
DataConditionInterface.__dict__[key].__set__(self, value)
elif key in ('_condition_type', '_problem', 'problem', 'condition_type'):
elif key in ('problem', 'condition_type'):
super().__setattr__(key, value)

View File

@@ -29,5 +29,5 @@ class DomainEquationCondition(ConditionInterface):
elif key == 'equation':
check_consistency(value, (EquationInterface))
DomainEquationCondition.__dict__[key].__set__(self, value)
elif key in ('_condition_type', '_problem', 'problem', 'condition_type'):
elif key in ('problem', 'condition_type'):
super().__setattr__(key, value)

View File

@@ -30,5 +30,5 @@ class InputPointsEquationCondition(ConditionInterface):
elif key == 'equation':
check_consistency(value, (EquationInterface))
InputPointsEquationCondition.__dict__[key].__set__(self, value)
elif key in ('_condition_type', '_problem', 'problem', 'condition_type'):
elif key in ('problem', 'condition_type'):
super().__setattr__(key, value)

View File

@@ -27,5 +27,5 @@ class InputOutputPointsCondition(ConditionInterface):
if (key == 'input_points') or (key == 'output_points'):
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
InputOutputPointsCondition.__dict__[key].__set__(self, value)
elif key in ('_condition_type', '_problem', 'problem', 'condition_type'):
elif key in ('problem', 'condition_type'):
super().__setattr__(key, value)

View File

@@ -20,7 +20,7 @@ class AbstractProblem(metaclass=ABCMeta):
def __init__(self):
# create collector to manage problem data
self.collector = Collector(self)
self._collector = Collector(self)
# create hook conditions <-> problems
for condition_name in self.conditions:
@@ -33,7 +33,12 @@ class AbstractProblem(metaclass=ABCMeta):
# points are ready.
self.collector.store_fixed_data()
@property
def collector(self):
return self._collector
# TODO: this should be erased when dataloading interfaces the collector;
# kept only for backward compatibility
@property
def input_pts(self):
to_return = {}
@@ -41,10 +46,6 @@ class AbstractProblem(metaclass=ABCMeta):
if 'input_points' in v.keys():
to_return[k] = v['input_points']
return to_return
@property
def _have_sampled_points(self):
return self.collector.is_conditions_ready
def __deepcopy__(self, memo):
"""
@@ -160,7 +161,9 @@ class AbstractProblem(metaclass=ABCMeta):
# check correct location
if locations == "all":
locations = [name for name in self.conditions.keys()]
locations = [name for name in self.conditions.keys()
if isinstance(self.conditions[name],
DomainEquationCondition)]
else:
if not isinstance(locations, (list)):
locations = [locations]
@@ -168,7 +171,7 @@ class AbstractProblem(metaclass=ABCMeta):
if not isinstance(self.conditions[loc], DomainEquationCondition):
raise TypeError(
f"Wrong locations passed, locations for sampling "
f"should be in {[loc for loc in locations if not isinstance(self.conditions[loc], DomainEquationCondition)]}.",
f"should be in {[loc for loc in locations if isinstance(self.conditions[loc], DomainEquationCondition)]}.",
)
# store data

View File

@@ -32,29 +32,18 @@ class Trainer(pytorch_lightning.Trainer):
if batch_size is not None:
check_consistency(batch_size, int)
self._model = solver
self.solver = solver
self.batch_size = batch_size
self._create_loader()
self._move_to_device()
# create dataloader
# if solver.problem.have_sampled_points is False:
# raise RuntimeError(
# f"Input points in {solver.problem.not_sampled_points} "
# "training are None. Please "
# "sample points in your problem by calling "
# "discretise_domain function before train "
# "in the provided locations."
# )
# self._create_or_update_loader()
def _move_to_device(self):
device = self._accelerator_connector._parallel_devices[0]
# move parameters to device
pb = self._model.problem
pb = self.solver.problem
if hasattr(pb, "unknown_parameters"):
for key in pb.unknown_parameters:
pb.unknown_parameters[key] = torch.nn.Parameter(
@@ -67,14 +56,21 @@ class Trainer(pytorch_lightning.Trainer):
during training, there is no need to touch the
trainer dataloader; just call the method.
"""
if not self.solver.problem.collector.full:
error_message = '\n'.join(
[f'{" " * 13} ---> Condition {key} {"sampled" if value else "not sampled"}'
for key, value in self.solver.problem.collector._is_conditions_ready.items()])
raise RuntimeError('Cannot create Trainer if not all conditions '
'are sampled. The Trainer got the following:\n'
f'{error_message}')
devices = self._accelerator_connector._parallel_devices
if len(devices) > 1:
raise RuntimeError("Parallel training is not supported yet.")
device = devices[0]
dataset_phys = SamplePointDataset(self._model.problem, device)
dataset_data = DataPointDataset(self._model.problem, device)
dataset_phys = SamplePointDataset(self.solver.problem, device)
dataset_data = DataPointDataset(self.solver.problem, device)
self._loader = SamplePointLoader(
dataset_phys, dataset_data, batch_size=self.batch_size, shuffle=True
)
@@ -84,7 +80,7 @@ class Trainer(pytorch_lightning.Trainer):
Train the solver method.
"""
return super().fit(
self._model, train_dataloaders=self._loader, **kwargs
self.solver, train_dataloaders=self._loader, **kwargs
)
@property
@@ -92,4 +88,4 @@ class Trainer(pytorch_lightning.Trainer):
"""
Returning trainer solver.
"""
return self._model
return self._solver

View File

@@ -89,6 +89,7 @@ def test_discretise_domain():
poisson_problem.discretise_domain(n, 'lh', locations=['D'])
assert poisson_problem.input_pts['D'].shape[0] == n
poisson_problem.discretise_domain(n)
def test_sampling_few_variables():
n = 10
@@ -98,7 +99,7 @@ def test_sampling_few_variables():
locations=['D'],
variables=['x'])
assert poisson_problem.input_pts['D'].shape[1] == 1
assert poisson_problem._have_sampled_points['D'] is False
assert poisson_problem.collector._is_conditions_ready['D'] is False
def test_variables_correct_order_sampling():
@@ -140,3 +141,12 @@ def test_add_points():
poisson_problem.add_points({'D': new_pts})
assert torch.isclose(poisson_problem.input_pts['D'].extract('x'), new_pts.extract('x'))
assert torch.isclose(poisson_problem.input_pts['D'].extract('y'), new_pts.extract('y'))
def test_collector():
poisson_problem = Poisson()
collector = poisson_problem.collector
assert collector.full is False
assert collector._is_conditions_ready['data'] is True
poisson_problem.discretise_domain(10)
assert collector.full is True