Commit b9753c34b2: minor changes / trainer update
Committed by Nicola Demo
Parent commit: 7528f6ef74
@@ -5,21 +5,35 @@ from .utils import check_consistency, merge_tensors
 class Collector:
     def __init__(self, problem):
-        self.problem = problem # hook Collector <-> Problem
-        self.data_collections = {name : {} for name in self.problem.conditions} # collection of data
-        self.is_conditions_ready = {
-            name : False for name in self.problem.conditions} # names of the conditions that need to be sampled
-        self.full = False # collector full, all points for all conditions are given and the data are ready to be used in trainig
+        # creating a hook between collector and problem
+        self.problem = problem
+
+        # this variable is used to store the data in the form:
+        # {'[condition_name]' :
+        #     {'input_points' : Tensor,
+        #      '[equation/output_points/conditional_variables]' : Tensor}
+        # }
+        # those variables are used for the dataloading
+        self._data_collections = {name : {} for name in self.problem.conditions}
+
+        # variables used to check that all conditions are sampled
+        self._is_conditions_ready = {
+            name : False for name in self.problem.conditions}
+        self.full = False

     @property
     def full(self):
-        return all(self.is_conditions_ready.values())
+        return all(self._is_conditions_ready.values())

     @full.setter
     def full(self, value):
         check_consistency(value, bool)
         self._full = value

+    @property
+    def data_collections(self):
+        return self._data_collections
+
     @property
     def problem(self):
         return self._problem
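To make the new bookkeeping concrete, here is a minimal standalone mock of the storage layout described in the comments above (hypothetical condition names, plain Python objects instead of PINA tensors; not the actual Collector class):

class MiniCollector:
    def __init__(self, condition_names):
        # {'condition_name': {'input_points': ..., 'equation'/'output_points'/...: ...}}
        self._data_collections = {name: {} for name in condition_names}
        # one readiness flag per condition; the collector is "full" when all are True
        self._is_conditions_ready = {name: False for name in condition_names}

    @property
    def full(self):
        return all(self._is_conditions_ready.values())

collector = MiniCollector(['gamma1', 'laplace_equation'])
print(collector.full)   # False: no condition sampled yet
collector._data_collections['gamma1'] = {'input_points': [0.0, 1.0]}
collector._is_conditions_ready['gamma1'] = True
print(collector.full)   # still False: 'laplace_equation' is missing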
@@ -33,13 +47,13 @@ class Collector:
         for condition_name, condition in self.problem.conditions.items():
             # if the condition is not ready and domain is not attribute
             # of condition, we get and store the data
-            if (not self.is_conditions_ready[condition_name]) and (not hasattr(condition, "domain")):
+            if (not self._is_conditions_ready[condition_name]) and (not hasattr(condition, "domain")):
                 # get data
                 keys = condition.__slots__
                 values = [getattr(condition, name) for name in keys]
                 self.data_collections[condition_name] = dict(zip(keys, values))
                 # condition now is ready
-                self.is_conditions_ready[condition_name] = True
+                self._is_conditions_ready[condition_name] = True

     def store_sample_domains(self, n, mode, variables, sample_locations):
         # loop over all locations
@@ -48,7 +62,7 @@ class Collector:
             condition = self.problem.conditions[loc]
             keys = ["input_points", "equation"]
             # if the condition is not ready, we get and store the data
-            if (not self.is_conditions_ready[loc]):
+            if (not self._is_conditions_ready[loc]):
                 # if it is the first time we sample
                 if not self.data_collections[loc]:
                     already_sampled = []
@@ -57,7 +71,7 @@ class Collector:
                     already_sampled = [self.data_collections[loc]['input_points']]
             # if the condition is ready but we want to sample again
             else:
-                self.is_conditions_ready[loc] = False
+                self._is_conditions_ready[loc] = False
                 already_sampled = []

             # get the samples
@@ -70,7 +84,7 @@ class Collector:
             ):
                 pts = pts.sort_labels()
                 if sorted(pts.labels)==sorted(self.problem.input_variables):
-                    self.is_conditions_ready[loc] = True
+                    self._is_conditions_ready[loc] = True
                 values = [pts, condition.equation]
                 self.data_collections[loc] = dict(zip(keys, values))
             else:
@@ -84,6 +98,6 @@ class Collector:
         :raises RuntimeError: if at least one condition is not already sampled
         """
         for k,v in new_points_dict.items():
-            if not self.is_conditions_ready[k]:
+            if not self._is_conditions_ready[k]:
                 raise RuntimeError('Cannot add points on a non sampled condition')
             self.data_collections[k]['input_points'] = self.data_collections[k]['input_points'].vstack(v)
@@ -13,20 +13,20 @@ class DataConditionInterface(ConditionInterface):
     distribution
     """

-    __slots__ = ["data", "conditionalvariable"]
+    __slots__ = ["input_points", "conditional_variables"]

-    def __init__(self, data, conditionalvariable=None):
+    def __init__(self, input_points, conditional_variables=None):
         """
         TODO
         """
         super().__init__()
-        self.data = data
-        self.conditionalvariable = conditionalvariable
+        self.input_points = input_points
+        self.conditional_variables = conditional_variables
         self.condition_type = 'unsupervised'

     def __setattr__(self, key, value):
-        if (key == 'data') or (key == 'conditionalvariable'):
+        if (key == 'input_points') or (key == 'conditional_variables'):
             check_consistency(value, (LabelTensor, Graph, torch.Tensor))
             DataConditionInterface.__dict__[key].__set__(self, value)
-        elif key in ('_condition_type', '_problem', 'problem', 'condition_type'):
+        elif key in ('problem', 'condition_type'):
             super().__setattr__(key, value)
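The __setattr__ above validates the renamed attributes and then writes them through the slot descriptor (DataConditionInterface.__dict__[key].__set__). A standalone sketch of that pattern, with a mock class and a simplified type check standing in for check_consistency:

class MockCondition:
    __slots__ = ["input_points", "conditional_variables"]

    def __setattr__(self, key, value):
        if key in ("input_points", "conditional_variables"):
            if value is not None and not isinstance(value, (list, tuple)):
                raise TypeError(f"unexpected type for {key}: {type(value)}")
            # write directly through the slot descriptor, bypassing __setattr__
            MockCondition.__dict__[key].__set__(self, value)
        else:
            raise AttributeError(f"{key} is not an allowed attribute")

c = MockCondition()
c.input_points = [0.0, 0.5, 1.0]   # validated, then stored in the slot
print(c.input_points)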
@@ -29,5 +29,5 @@ class DomainEquationCondition(ConditionInterface):
         elif key == 'equation':
             check_consistency(value, (EquationInterface))
             DomainEquationCondition.__dict__[key].__set__(self, value)
-        elif key in ('_condition_type', '_problem', 'problem', 'condition_type'):
+        elif key in ('problem', 'condition_type'):
             super().__setattr__(key, value)
@@ -30,5 +30,5 @@ class InputPointsEquationCondition(ConditionInterface):
         elif key == 'equation':
             check_consistency(value, (EquationInterface))
             InputPointsEquationCondition.__dict__[key].__set__(self, value)
-        elif key in ('_condition_type', '_problem', 'problem', 'condition_type'):
+        elif key in ('problem', 'condition_type'):
             super().__setattr__(key, value)
@@ -27,5 +27,5 @@ class InputOutputPointsCondition(ConditionInterface):
         if (key == 'input_points') or (key == 'output_points'):
             check_consistency(value, (LabelTensor, Graph, torch.Tensor))
             InputOutputPointsCondition.__dict__[key].__set__(self, value)
-        elif key in ('_condition_type', '_problem', 'problem', 'condition_type'):
+        elif key in ('problem', 'condition_type'):
             super().__setattr__(key, value)
@@ -20,7 +20,7 @@ class AbstractProblem(metaclass=ABCMeta):
     def __init__(self):

         # create collector to manage problem data
-        self.collector = Collector(self)
+        self._collector = Collector(self)

         # create hook conditions <-> problems
         for condition_name in self.conditions:
@@ -33,7 +33,12 @@ class AbstractProblem(metaclass=ABCMeta):
         # points are ready.
         self.collector.store_fixed_data()

+    @property
+    def collector(self):
+        return self._collector
+
+    # TODO this should be erase when dataloading will interface collector,
+    # kept only for back compatibility
     @property
     def input_pts(self):
         to_return = {}
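For context, the input_pts property above is a back-compatibility adapter: it flattens the new nested collections into the old {condition: points} mapping. A standalone mock with hypothetical condition names and plain lists in place of tensors:

class MockProblem:
    def __init__(self):
        self._data_collections = {
            'gamma1': {'input_points': [0.0, 1.0], 'equation': 'bc'},
            'laplace': {'input_points': [0.25, 0.75], 'equation': 'pde'},
        }

    @property
    def input_pts(self):
        # expose only the sampled points, keyed by condition name
        return {k: v['input_points'] for k, v in self._data_collections.items()}

print(MockProblem().input_pts)   # {'gamma1': [0.0, 1.0], 'laplace': [0.25, 0.75]}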
@@ -42,10 +47,6 @@ class AbstractProblem(metaclass=ABCMeta):
             to_return[k] = v['input_points']
         return to_return

-    @property
-    def _have_sampled_points(self):
-        return self.collector.is_conditions_ready
-
     def __deepcopy__(self, memo):
         """
         Implements deepcopy for the
@@ -160,7 +161,9 @@ class AbstractProblem(metaclass=ABCMeta):
         # check correct location
         if locations == "all":
-            locations = [name for name in self.conditions.keys()]
+            locations = [name for name in self.conditions.keys()
+                         if isinstance(self.conditions[name],
+                                       DomainEquationCondition)]
         else:
             if not isinstance(locations, (list)):
                 locations = [locations]
@@ -168,7 +171,7 @@ class AbstractProblem(metaclass=ABCMeta):
             if not isinstance(self.conditions[loc], DomainEquationCondition):
                 raise TypeError(
                     f"Wrong locations passed, locations for sampling "
-                    f"should be in {[loc for loc in locations if not isinstance(self.conditions[loc], DomainEquationCondition)]}.",
+                    f"should be in {[loc for loc in locations if isinstance(self.conditions[loc], DomainEquationCondition)]}.",
                 )

         # store data
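The change above means locations="all" now selects only conditions that actually carry a domain and equation to sample. A standalone sketch with stand-in classes (not the PINA implementations) and hypothetical condition names:

class DomainEquationCondition:      # stand-in: conditions that can be sampled
    pass

class InputOutputPointsCondition:   # stand-in: data conditions, never sampled
    pass

conditions = {
    'gamma1': DomainEquationCondition(),
    'laplace_equation': DomainEquationCondition(),
    'data': InputOutputPointsCondition(),
}

locations = [name for name in conditions.keys()
             if isinstance(conditions[name], DomainEquationCondition)]
print(locations)   # ['gamma1', 'laplace_equation'] -- 'data' is skipped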
@@ -32,29 +32,18 @@ class Trainer(pytorch_lightning.Trainer):
         if batch_size is not None:
             check_consistency(batch_size, int)

-        self._model = solver
+        self.solver = solver
         self.batch_size = batch_size

         self._create_loader()
         self._move_to_device()

-        # create dataloader
-        # if solver.problem.have_sampled_points is False:
-        #     raise RuntimeError(
-        #         f"Input points in {solver.problem.not_sampled_points} "
-        #         "training are None. Please "
-        #         "sample points in your problem by calling "
-        #         "discretise_domain function before train "
-        #         "in the provided locations."
-        #     )
-
-        # self._create_or_update_loader()

     def _move_to_device(self):
         device = self._accelerator_connector._parallel_devices[0]

         # move parameters to device
-        pb = self._model.problem
+        pb = self.solver.problem
         if hasattr(pb, "unknown_parameters"):
             for key in pb.unknown_parameters:
                 pb.unknown_parameters[key] = torch.nn.Parameter(
@@ -67,14 +56,21 @@ class Trainer(pytorch_lightning.Trainer):
         during training, there is no need to define to touch the
         trainer dataloader, just call the method.
         """
+        if not self.solver.problem.collector.full:
+            error_message = '\n'.join(
+                [f'{" " * 13} ---> Condition {key} {"sampled" if value else "not sampled"}'
+                 for key, value in self.solver.problem.collector._is_conditions_ready.items()])
+            raise RuntimeError('Cannot create Trainer if not all conditions '
+                               'are sampled. The Trainer got the following:\n'
+                               f'{error_message}')
         devices = self._accelerator_connector._parallel_devices

         if len(devices) > 1:
             raise RuntimeError("Parallel training is not supported yet.")

         device = devices[0]
-        dataset_phys = SamplePointDataset(self._model.problem, device)
-        dataset_data = DataPointDataset(self._model.problem, device)
+        dataset_phys = SamplePointDataset(self.solver.problem, device)
+        dataset_data = DataPointDataset(self.solver.problem, device)
         self._loader = SamplePointLoader(
             dataset_phys, dataset_data, batch_size=self.batch_size, shuffle=True
         )
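The new guard in _create_loader refuses to build the dataloaders until every condition is sampled and reports each condition's status. A standalone sketch of the message it builds, using a plain dict of hypothetical readiness flags instead of the real collector:

is_conditions_ready = {'gamma1': True, 'laplace_equation': False}

if not all(is_conditions_ready.values()):
    error_message = '\n'.join(
        [f'{" " * 13} ---> Condition {key} {"sampled" if value else "not sampled"}'
         for key, value in is_conditions_ready.items()])
    # the Trainer raises RuntimeError with this text; printed here for illustration
    print('Cannot create Trainer if not all conditions '
          'are sampled. The Trainer got the following:\n'
          f'{error_message}')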
@@ -84,7 +80,7 @@ class Trainer(pytorch_lightning.Trainer):
         Train the solver method.
         """
         return super().fit(
-            self._model, train_dataloaders=self._loader, **kwargs
+            self.solver, train_dataloaders=self._loader, **kwargs
         )

     @property
@@ -92,4 +88,4 @@ class Trainer(pytorch_lightning.Trainer):
         """
         Returning trainer solver.
         """
-        return self._model
+        return self._solver
@@ -89,6 +89,7 @@ def test_discretise_domain():
     poisson_problem.discretise_domain(n, 'lh', locations=['D'])
     assert poisson_problem.input_pts['D'].shape[0] == n

+    poisson_problem.discretise_domain(n)

 def test_sampling_few_variables():
     n = 10
@@ -98,7 +99,7 @@ def test_sampling_few_variables():
                                       locations=['D'],
                                       variables=['x'])
     assert poisson_problem.input_pts['D'].shape[1] == 1
-    assert poisson_problem._have_sampled_points['D'] is False
+    assert poisson_problem.collector._is_conditions_ready['D'] is False


 def test_variables_correct_order_sampling():
@@ -140,3 +141,12 @@ def test_add_points():
     poisson_problem.add_points({'D': new_pts})
     assert torch.isclose(poisson_problem.input_pts['D'].extract('x'), new_pts.extract('x'))
     assert torch.isclose(poisson_problem.input_pts['D'].extract('y'), new_pts.extract('y'))
+
+
+def test_collector():
+    poisson_problem = Poisson()
+    collector = poisson_problem.collector
+    assert collector.full is False
+    assert collector._is_conditions_ready['data'] is True
+    poisson_problem.discretise_domain(10)
+    assert collector.full is True