minor changes / trainer update

committed by Nicola Demo
parent 7528f6ef74
commit b9753c34b2
@@ -5,21 +5,35 @@ from .utils import check_consistency, merge_tensors
 
 class Collector:
     def __init__(self, problem):
-        self.problem = problem # hook Collector <-> Problem
-        self.data_collections = {name : {} for name in self.problem.conditions} # collection of data
-        self.is_conditions_ready = {
-            name : False for name in self.problem.conditions} # names of the conditions that need to be sampled
-        self.full = False # collector full, all points for all conditions are given and the data are ready to be used in trainig
+        # creating a hook between collector and problem
+        self.problem = problem
+
+        # this variable is used to store the data in the form:
+        # {'[condition_name]' :
+        #      {'input_points' : Tensor,
+        #       '[equation/output_points/conditional_variables]': Tensor}
+        # }
+        # those variables are used for the dataloading
+        self._data_collections = {name : {} for name in self.problem.conditions}
+
+        # variables used to check that all conditions are sampled
+        self._is_conditions_ready = {
+            name : False for name in self.problem.conditions}
+        self.full = False
 
     @property
     def full(self):
-        return all(self.is_conditions_ready.values())
+        return all(self._is_conditions_ready.values())
 
     @full.setter
     def full(self, value):
         check_consistency(value, bool)
         self._full = value
 
+    @property
+    def data_collections(self):
+        return self._data_collections
+
     @property
     def problem(self):
         return self._problem
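For context, the refactor above swaps the Collector's public bookkeeping attributes for underscore-prefixed ones exposed through read-only properties, with `full` derived from the per-condition flags rather than stored. A minimal self-contained sketch of the same pattern (the condition names are made up for illustration):

    class MiniCollector:
        # toy stand-in for Collector, showing the private-state + property pattern
        def __init__(self, condition_names):
            # private storage, exposed read-only via the properties below
            self._data_collections = {name: {} for name in condition_names}
            self._is_conditions_ready = {name: False for name in condition_names}

        @property
        def data_collections(self):
            return self._data_collections

        @property
        def full(self):
            # derived state: True only once every condition has been sampled
            return all(self._is_conditions_ready.values())

    collector = MiniCollector(['gamma1', 'D'])
    print(collector.full)  # False
    collector._is_conditions_ready.update(gamma1=True, D=True)
    print(collector.full)  # True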
@@ -33,13 +47,13 @@ class Collector:
         for condition_name, condition in self.problem.conditions.items():
             # if the condition is not ready and domain is not attribute
             # of condition, we get and store the data
-            if (not self.is_conditions_ready[condition_name]) and (not hasattr(condition, "domain")):
+            if (not self._is_conditions_ready[condition_name]) and (not hasattr(condition, "domain")):
                 # get data
                 keys = condition.__slots__
                 values = [getattr(condition, name) for name in keys]
                 self.data_collections[condition_name] = dict(zip(keys, values))
                 # condition now is ready
-                self.is_conditions_ready[condition_name] = True
+                self._is_conditions_ready[condition_name] = True
 
     def store_sample_domains(self, n, mode, variables, sample_locations):
         # loop over all locations
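The `store_fixed_data` path above relies on each condition declaring its payload via `__slots__`, so the stored dict can be built generically with `getattr`. A quick illustration with a hypothetical condition class:

    import torch

    class ToyCondition:
        # __slots__ doubles as the list of keys to store, as in store_fixed_data
        __slots__ = ["input_points", "output_points"]
        def __init__(self, input_points, output_points):
            self.input_points = input_points
            self.output_points = output_points

    condition = ToyCondition(torch.rand(4, 2), torch.rand(4, 1))
    keys = condition.__slots__
    values = [getattr(condition, name) for name in keys]
    print(dict(zip(keys, values)).keys())  # dict_keys(['input_points', 'output_points'])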
@@ -48,7 +62,7 @@ class Collector:
             condition = self.problem.conditions[loc]
             keys = ["input_points", "equation"]
             # if the condition is not ready, we get and store the data
-            if (not self.is_conditions_ready[loc]):
+            if (not self._is_conditions_ready[loc]):
                 # if it is the first time we sample
                 if not self.data_collections[loc]:
                     already_sampled = []
@@ -57,7 +71,7 @@ class Collector:
                     already_sampled = [self.data_collections[loc]['input_points']]
             # if the condition is ready but we want to sample again
             else:
-                self.is_conditions_ready[loc] = False
+                self._is_conditions_ready[loc] = False
                 already_sampled = []
 
             # get the samples
@@ -70,7 +84,7 @@ class Collector:
             ):
             pts = pts.sort_labels()
             if sorted(pts.labels)==sorted(self.problem.input_variables):
-                self.is_conditions_ready[loc] = True
+                self._is_conditions_ready[loc] = True
                 values = [pts, condition.equation]
                 self.data_collections[loc] = dict(zip(keys, values))
             else:
@@ -84,6 +98,6 @@ class Collector:
         :raises RuntimeError: if at least one condition is not already sampled
         """
         for k,v in new_points_dict.items():
-            if not self.is_conditions_ready[k]:
+            if not self._is_conditions_ready[k]:
                 raise RuntimeError('Cannot add points on a non sampled condition')
             self.data_collections[k]['input_points'] = self.data_collections[k]['input_points'].vstack(v)
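`add_points` therefore only extends conditions that were already sampled, stacking the new points under the existing ones. A rough equivalent with plain `torch` tensors in place of `LabelTensor` (dictionary names are illustrative):

    import torch

    data_collections = {'gamma1': {'input_points': torch.rand(10, 2)}}
    is_conditions_ready = {'gamma1': True, 'D': False}

    def add_points(new_points_dict):
        for k, v in new_points_dict.items():
            if not is_conditions_ready[k]:
                raise RuntimeError('Cannot add points on a non sampled condition')
            # plain-tensor stand-in for LabelTensor.vstack
            data_collections[k]['input_points'] = torch.vstack(
                [data_collections[k]['input_points'], v])

    add_points({'gamma1': torch.rand(5, 2)})
    print(data_collections['gamma1']['input_points'].shape)  # torch.Size([15, 2])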
@@ -13,20 +13,20 @@ class DataConditionInterface(ConditionInterface):
     distribution
     """
 
-    __slots__ = ["data", "conditionalvariable"]
+    __slots__ = ["input_points", "conditional_variables"]
 
-    def __init__(self, data, conditionalvariable=None):
+    def __init__(self, input_points, conditional_variables=None):
         """
         TODO
         """
         super().__init__()
-        self.data = data
-        self.conditionalvariable = conditionalvariable
+        self.input_points = input_points
+        self.conditional_variables = conditional_variables
         self.condition_type = 'unsupervised'
 
     def __setattr__(self, key, value):
-        if (key == 'data') or (key == 'conditionalvariable'):
+        if (key == 'input_points') or (key == 'conditional_variables'):
             check_consistency(value, (LabelTensor, Graph, torch.Tensor))
             DataConditionInterface.__dict__[key].__set__(self, value)
-        elif key in ('_condition_type', '_problem', 'problem', 'condition_type'):
+        elif key in ('problem', 'condition_type'):
             super().__setattr__(key, value)
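The `__setattr__` override above validates the renamed slot attributes with `check_consistency` and then writes through the slot descriptor stored in the class `__dict__`, which avoids recursing back into `__setattr__`. A stripped-down sketch of that mechanism, with the validation reduced to a bare type check:

    import torch

    class ToySlotCondition:
        __slots__ = ["input_points"]

        def __setattr__(self, key, value):
            if key == 'input_points':
                if not isinstance(value, torch.Tensor):  # stand-in for check_consistency
                    raise ValueError('input_points must be a tensor')
                # write via the slot descriptor, bypassing this __setattr__
                ToySlotCondition.__dict__[key].__set__(self, value)
            else:
                super().__setattr__(key, value)

    c = ToySlotCondition()
    c.input_points = torch.zeros(3, 2)   # accepted
    print(c.input_points.shape)          # torch.Size([3, 2])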
@@ -29,5 +29,5 @@ class DomainEquationCondition(ConditionInterface):
         elif key == 'equation':
             check_consistency(value, (EquationInterface))
             DomainEquationCondition.__dict__[key].__set__(self, value)
-        elif key in ('_condition_type', '_problem', 'problem', 'condition_type'):
+        elif key in ('problem', 'condition_type'):
             super().__setattr__(key, value)
@@ -30,5 +30,5 @@ class InputPointsEquationCondition(ConditionInterface):
         elif key == 'equation':
             check_consistency(value, (EquationInterface))
             InputPointsEquationCondition.__dict__[key].__set__(self, value)
-        elif key in ('_condition_type', '_problem', 'problem', 'condition_type'):
+        elif key in ('problem', 'condition_type'):
             super().__setattr__(key, value)
@@ -27,5 +27,5 @@ class InputOutputPointsCondition(ConditionInterface):
         if (key == 'input_points') or (key == 'output_points'):
             check_consistency(value, (LabelTensor, Graph, torch.Tensor))
             InputOutputPointsCondition.__dict__[key].__set__(self, value)
-        elif key in ('_condition_type', '_problem', 'problem', 'condition_type'):
+        elif key in ('problem', 'condition_type'):
             super().__setattr__(key, value)
@@ -20,7 +20,7 @@ class AbstractProblem(metaclass=ABCMeta):
     def __init__(self):
 
         # create collector to manage problem data
-        self.collector = Collector(self)
+        self._collector = Collector(self)
 
         # create hook conditions <-> problems
         for condition_name in self.conditions:
@@ -33,7 +33,12 @@ class AbstractProblem(metaclass=ABCMeta):
         # points are ready.
         self.collector.store_fixed_data()
 
+    @property
+    def collector(self):
+        return self._collector
+
     # TODO: this should be erased when dataloading will interface the collector,
     # kept only for backward compatibility
     @property
     def input_pts(self):
         to_return = {}
@@ -41,10 +46,6 @@ class AbstractProblem(metaclass=ABCMeta):
             if 'input_points' in v.keys():
                 to_return[k] = v['input_points']
         return to_return
 
-    @property
-    def _have_sampled_points(self):
-        return self.collector.is_conditions_ready
-
     def __deepcopy__(self, memo):
         """
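For reference, the surviving back-compat `input_pts` property just projects the collector's data onto the `input_points` entries, keyed by condition name. Roughly (the data below is mocked):

    import torch

    # mocked collector data; real entries also hold equations / output points
    data_collections = {
        'gamma1': {'input_points': torch.rand(10, 2), 'equation': object()},
        'data': {'input_points': torch.rand(5, 2), 'output_points': torch.rand(5, 1)},
    }

    input_pts = {k: v['input_points']
                 for k, v in data_collections.items()
                 if 'input_points' in v}
    print(sorted(input_pts))  # ['data', 'gamma1']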
@@ -160,7 +161,9 @@ class AbstractProblem(metaclass=ABCMeta):
 
         # check correct location
         if locations == "all":
-            locations = [name for name in self.conditions.keys()]
+            locations = [name for name in self.conditions.keys()
+                         if isinstance(self.conditions[name],
+                                       DomainEquationCondition)]
         else:
             if not isinstance(locations, (list)):
                 locations = [locations]
@@ -168,7 +171,7 @@ class AbstractProblem(metaclass=ABCMeta):
             if not isinstance(self.conditions[loc], DomainEquationCondition):
                 raise TypeError(
                     f"Wrong locations passed, locations for sampling "
-                    f"should be in {[loc for loc in locations if not isinstance(self.conditions[loc], DomainEquationCondition)]}.",
+                    f"should be in {[loc for loc in locations if isinstance(self.conditions[loc], DomainEquationCondition)]}.",
                 )
 
         # store data
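Both changes make the location check consistent: with "all", only equation-type conditions are selected for domain sampling, and the corrected error message lists the valid locations rather than the invalid ones. A toy illustration of the filtering (the classes below are empty stand-ins):

    class DomainEquationCondition: ...
    class InputOutputPointsCondition: ...

    conditions = {'gamma1': DomainEquationCondition(),
                  'D': DomainEquationCondition(),
                  'data': InputOutputPointsCondition()}

    locations = [name for name in conditions
                 if isinstance(conditions[name], DomainEquationCondition)]
    print(locations)  # ['gamma1', 'D'] -- 'data' has no domain to sample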
@@ -32,29 +32,18 @@ class Trainer(pytorch_lightning.Trainer):
         if batch_size is not None:
             check_consistency(batch_size, int)
 
-        self._model = solver
+        self.solver = solver
         self.batch_size = batch_size
 
         self._create_loader()
         self._move_to_device()
 
-        # create dataloader
-        # if solver.problem.have_sampled_points is False:
-        #     raise RuntimeError(
-        #         f"Input points in {solver.problem.not_sampled_points} "
-        #         "training are None. Please "
-        #         "sample points in your problem by calling "
-        #         "discretise_domain function before train "
-        #         "in the provided locations."
-        #     )
-
-        # self._create_or_update_loader()
-
     def _move_to_device(self):
         device = self._accelerator_connector._parallel_devices[0]
 
         # move parameters to device
-        pb = self._model.problem
+        pb = self.solver.problem
         if hasattr(pb, "unknown_parameters"):
             for key in pb.unknown_parameters:
                 pb.unknown_parameters[key] = torch.nn.Parameter(
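The hunk cuts off inside the `torch.nn.Parameter(` call, so the exact wrapping is not shown. The general pattern for re-homing learnable problem parameters, sketched under that assumption:

    import torch

    # hedged sketch: re-wrap each learnable parameter on the target device;
    # the dict mirrors pb.unknown_parameters from the diff, whose exact call is elided
    unknown_parameters = {'mu': torch.nn.Parameter(torch.tensor([0.5]))}
    device = torch.device('cpu')  # e.g. the first accelerator device

    for key in unknown_parameters:
        unknown_parameters[key] = torch.nn.Parameter(
            unknown_parameters[key].data.to(device),
            requires_grad=unknown_parameters[key].requires_grad)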
@@ -67,14 +56,21 @@ class Trainer(pytorch_lightning.Trainer):
         during training, there is no need to touch the
         trainer dataloader; just call the method.
         """
+        if not self.solver.problem.collector.full:
+            error_message = '\n'.join(
+                [f'{" " * 13} ---> Condition {key} {"sampled" if value else "not sampled"}'
+                 for key, value in self.solver.problem.collector._is_conditions_ready.items()])
+            raise RuntimeError('Cannot create Trainer if not all conditions '
+                               'are sampled. The Trainer got the following:\n'
+                               f'{error_message}')
         devices = self._accelerator_connector._parallel_devices
 
         if len(devices) > 1:
             raise RuntimeError("Parallel training is not supported yet.")
 
         device = devices[0]
-        dataset_phys = SamplePointDataset(self._model.problem, device)
-        dataset_data = DataPointDataset(self._model.problem, device)
+        dataset_phys = SamplePointDataset(self.solver.problem, device)
+        dataset_data = DataPointDataset(self.solver.problem, device)
         self._loader = SamplePointLoader(
             dataset_phys, dataset_data, batch_size=self.batch_size, shuffle=True
         )
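The new guard replaces the commented-out check deleted in the earlier Trainer hunk: instead of failing vaguely, it prints one line per condition. With a half-sampled problem (condition names invented), the message it raises looks like this:

    # reproduce the error-message construction from the guard above
    is_conditions_ready = {'gamma1': True, 'D': False}  # illustrative names
    error_message = '\n'.join(
        [f'{" " * 13} ---> Condition {key} {"sampled" if value else "not sampled"}'
         for key, value in is_conditions_ready.items()])
    print(error_message)
    # (13 leading spaces) ---> Condition gamma1 sampled
    # (13 leading spaces) ---> Condition D not sampled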
@@ -84,7 +80,7 @@ class Trainer(pytorch_lightning.Trainer):
         Train the solver method.
         """
         return super().fit(
-            self._model, train_dataloaders=self._loader, **kwargs
+            self.solver, train_dataloaders=self._loader, **kwargs
         )
 
     @property
@@ -92,4 +88,4 @@ class Trainer(pytorch_lightning.Trainer):
         """
         Returns the trainer solver.
         """
-        return self._model
+        return self._solver