Add Graph support in Dataset and Dataloader

This commit is contained in:
FilippoOlivo
2024-10-23 15:04:28 +02:00
committed by Nicola Demo
parent eb146ea2ea
commit ccc5f5a322
11 changed files with 125 additions and 75 deletions

View File

@@ -38,7 +38,7 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
check_consistency(problem, AbstractProblem)
self._check_solver_consistency(problem)
#Check consistency of models argument and encapsulate in list
# Check consistency of models argument and encapsulate in list
if not isinstance(models, list):
check_consistency(models, torch.nn.Module)
# put everything in a list if only one input
@@ -49,17 +49,17 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
check_consistency(models[idx], torch.nn.Module)
len_model = len(models)
#If use_lt is true add extract operation in input
# If use_lt is true add extract operation in input
if use_lt is True:
for idx in range(len(models)):
for idx, model in enumerate(models):
models[idx] = Network(
model=models[idx],
model=model,
input_variables=problem.input_variables,
output_variables=problem.output_variables,
extra_features=extra_features,
)
#Check scheduler consistency + encapsulation
# Check scheduler consistency + encapsulation
if not isinstance(schedulers, list):
check_consistency(schedulers, Scheduler)
schedulers = [schedulers]
@@ -67,7 +67,7 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
for scheduler in schedulers:
check_consistency(scheduler, Scheduler)
#Check optimizer consistency + encapsulation
# Check optimizer consistency + encapsulation
if not isinstance(optimizers, list):
check_consistency(optimizers, Optimizer)
optimizers = [optimizers]
@@ -141,5 +141,6 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
if not set(self.accepted_condition_types).issubset(
condition.condition_type):
raise ValueError(
f'{self.__name__} support only dose not support condition {condition.condition_type}'
f'{self.__name__} support only dose not support condition '
f'{condition.condition_type}'
)