""" Solver module. """ import lightning import torch import sys from abc import ABCMeta, abstractmethod from ..problem import AbstractProblem from ..optim import Optimizer, Scheduler, TorchOptimizer, TorchScheduler from ..loss import WeightingInterface from ..loss.scalar_weighting import _NoWeighting from ..utils import check_consistency, labelize_forward from torch._dynamo.eval_frame import OptimizedModule class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): """ SolverInterface base class. This class is a wrapper of LightningModule. """ def __init__(self, problem, weighting, use_lt): """ :param problem: A problem definition instance. :type problem: AbstractProblem :param weighting: The loss weighting to use. :type weighting: WeightingInterface :param use_lt: Using LabelTensors as input during training. :type use_lt: bool """ super().__init__() # check consistency of the problem check_consistency(problem, AbstractProblem) self._check_solver_consistency(problem) self._pina_problem = problem # check consistency of the weighting and hook the condition names if weighting is None: weighting = _NoWeighting() check_consistency(weighting, WeightingInterface) self._pina_weighting = weighting weighting.condition_names = list(self._pina_problem.conditions.keys()) # check consistency use_lt check_consistency(use_lt, bool) self._use_lt = use_lt # if use_lt is true add extract operation in input if use_lt is True: self.forward = labelize_forward( forward=self.forward, input_variables=problem.input_variables, output_variables=problem.output_variables, ) # PINA private attributes (some are overridden by derived classes) self._pina_problem = problem self._pina_models = None self._pina_optimizers = None self._pina_schedulers = None def _check_solver_consistency(self, problem): for condition in problem.conditions.values(): check_consistency(condition, self.accepted_conditions_types) def _optimization_cycle(self, batch): """ Perform a private optimization cycle by computing the loss for each condition in the given batch. The loss are later aggregated using the specific weighting schema. :param batch: A batch of data, where each element is a tuple containing a condition name and a dictionary of points. :type batch: list of tuples (str, dict) :return: The computed loss for the all conditions in the batch, cast to a subclass of `torch.Tensor`. It should return a dict containing the condition name and the associated scalar loss. :rtype: dict(torch.Tensor) """ losses = self.optimization_cycle(batch) for name, value in losses.items(): self.store_log(f'{name}_loss', value.item(), self.get_batch_size(batch)) loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor) return loss def training_step(self, batch): """ Solver training step. :param batch: The batch element in the dataloader. :type batch: tuple :return: The sum of the loss functions. :rtype: LabelTensor """ loss = self._optimization_cycle(batch=batch) self.store_log('train_loss', loss, self.get_batch_size(batch)) return loss def validation_step(self, batch): """ Solver validation step. :param batch: The batch element in the dataloader. :type batch: tuple """ loss = self._optimization_cycle(batch=batch) self.store_log('val_loss', loss, self.get_batch_size(batch)) def test_step(self, batch): """ Solver test step. :param batch: The batch element in the dataloader. 
        :type batch: tuple
        """
        loss = self._optimization_cycle(batch=batch)
        self.store_log("test_loss", loss, self.get_batch_size(batch))

    def store_log(self, name, value, batch_size):
        self.log(
            name=name,
            value=value,
            batch_size=batch_size,
            **self.trainer.logging_kwargs,
        )

    @abstractmethod
    def forward(self, *args, **kwargs):
        pass

    @abstractmethod
    def optimization_cycle(self, batch):
        """
        Perform an optimization cycle by computing the loss for each condition
        in the given batch.

        :param batch: A batch of data, where each element is a tuple containing
            a condition name and a dictionary of points.
        :type batch: list of tuples (str, dict)
        :return: The computed loss for all conditions in the batch, cast to a
            subclass of :class:`torch.Tensor`. It should return a dict
            containing the condition name and the associated scalar loss.
        :rtype: dict(torch.Tensor)
        """
        pass

    @property
    def problem(self):
        """
        The problem formulation.
        """
        return self._pina_problem

    @property
    def use_lt(self):
        """
        Whether LabelTensors are used during training.
        """
        return self._use_lt

    @property
    def weighting(self):
        """
        The weighting mechanism.
        """
        return self._pina_weighting

    @staticmethod
    def get_batch_size(batch):
        # assuming batch is a custom Batch object
        batch_size = 0
        for data in batch:
            batch_size += len(data[1]["input_points"])
        return batch_size

    @staticmethod
    def default_torch_optimizer():
        return TorchOptimizer(torch.optim.Adam, lr=0.001)

    @staticmethod
    def default_torch_scheduler():
        return TorchScheduler(torch.optim.lr_scheduler.ConstantLR)

    def on_train_start(self):
        """
        Hook that is called before training begins. Used to compile the model
        if the trainer is set to compile.
        """
        super().on_train_start()
        if self.trainer.compile:
            self._compile_model()

    def on_test_start(self):
        """
        Hook that is called before testing begins. Used to compile the model
        if the trainer is set to compile and the model is not compiled yet.
        """
        super().on_test_start()
        if self.trainer.compile and not self._check_already_compiled():
            self._compile_model()

    def _check_already_compiled(self):
        models = self._pina_models
        if len(models) == 1 and isinstance(
            self._pina_models[0], torch.nn.ModuleDict
        ):
            models = list(self._pina_models[0].values())
        for model in models:
            if not isinstance(model, (OptimizedModule, torch.nn.ModuleDict)):
                return False
        return True

    @staticmethod
    def _perform_compilation(model):
        model_device = next(model.parameters()).device
        try:
            if model_device == torch.device("mps:0"):
                model = torch.compile(model, backend="eager")
            else:
                model = torch.compile(model, backend="inductor")
        except Exception as e:
            print("Compilation failed, running in normal mode:\n", e)
        return model


class SingleSolverInterface(SolverInterface):
    """
    Single solver base class. This class is a wrapper of SolverInterface for
    solvers using a single model.
    """

    def __init__(
        self,
        problem,
        model,
        optimizer=None,
        scheduler=None,
        weighting=None,
        use_lt=True,
    ):
        """
        :param problem: A problem definition instance.
        :type problem: AbstractProblem
        :param model: A torch nn.Module instance.
        :type model: torch.nn.Module
        :param Optimizer optimizer: The neural network optimizer to use.
        :param Scheduler scheduler: The neural network scheduler to use.
        :param WeightingInterface weighting: The loss weighting to use.
        :param bool use_lt: Whether to use LabelTensors as input during
            training.
""" if optimizer is None: optimizer = self.default_torch_optimizer() if scheduler is None: scheduler = self.default_torch_scheduler() super().__init__(problem=problem, use_lt=use_lt, weighting=weighting) # check consistency of models argument and encapsulate in list check_consistency(model, torch.nn.Module) # check scheduler consistency and encapsulate in list check_consistency(scheduler, Scheduler) # check optimizer consistency and encapsulate in list check_consistency(optimizer, Optimizer) # initialize the model (needed by Lightining to go to different devices) self._pina_models = torch.nn.ModuleList([model]) self._pina_optimizers = [optimizer] self._pina_schedulers = [scheduler] def forward(self, x): """ Forward pass implementation for the solver. :param torch.Tensor x: Input tensor. :return: Solver solution. :rtype: torch.Tensor """ x = self.model(x) return x def configure_optimizers(self): """ Optimizer configuration for the solver. :return: The optimizers and the schedulers :rtype: tuple(list, list) """ self.optimizer.hook(self.model.parameters()) self.scheduler.hook(self.optimizer) return ( [self.optimizer.instance], [self.scheduler.instance] ) def _compile_model(self): if isinstance(self._pina_models[0], torch.nn.ModuleDict): self._compile_module_dict() else: self._compile_single_model() def _compile_module_dict(self): for name, model in self._pina_models[0].items(): self._pina_models[0][name] = self._perform_compilation(model) def _compile_single_model(self): self._pina_models[0] = self._perform_compilation(self._pina_models[0]) @property def model(self): """ Model for training. """ return self._pina_models[0] @property def scheduler(self): """ Scheduler for training. """ return self._pina_schedulers[0] @property def optimizer(self): """ Optimizer for training. """ return self._pina_optimizers[0] class MultiSolverInterface(SolverInterface): """ Multiple Solver base class. This class inherits is a wrapper of SolverInterface class """ def __init__(self, problem, models, optimizers=None, schedulers=None, weighting=None, use_lt=True): """ :param problem: A problem definition instance. :type problem: AbstractProblem :param models: Multiple torch nn.Module instances. :type model: list[torch.nn.Module] | tuple[torch.nn.Module] :param list(Optimizer) optimizers: A list of neural network optimizers to use. :param list(Scheduler) optimizers: A list of neural network schedulers to use. :param WeightingInterface weighting: The loss weighting to use. :param bool use_lt: Using LabelTensors as input during training. """ if not isinstance(models, (list, tuple)) or len(models) < 2: raise ValueError( 'models should be list[torch.nn.Module] or ' 'tuple[torch.nn.Module] with len greater than ' 'one.' ) if any(opt is None for opt in optimizers): optimizers = [ self.default_torch_optimizer() if opt is None else opt for opt in optimizers ] if any(sched is None for sched in schedulers): schedulers = [ self.default_torch_scheduler() if sched is None else sched for sched in schedulers ] super().__init__(problem=problem, use_lt=use_lt, weighting=weighting) # check consistency of models argument and encapsulate in list check_consistency(models, torch.nn.Module) # check scheduler consistency and encapsulate in list check_consistency(schedulers, Scheduler) # check optimizer consistency and encapsulate in list check_consistency(optimizers, Optimizer) # check length consistency optimizers if len(models) != len(optimizers): raise ValueError( "You must define one optimizer for each model." 
f"Got {len(models)} models, and {len(optimizers)}" " optimizers." ) # initialize the model self._pina_models = torch.nn.ModuleList(models) self._pina_optimizers = optimizers self._pina_schedulers = schedulers def configure_optimizers(self): """Optimizer configuration for the solver. :return: The optimizers and the schedulers :rtype: tuple(list, list) """ for optimizer, scheduler, model in zip(self.optimizers, self.schedulers, self.models): optimizer.hook(model.parameters()) scheduler.hook(optimizer) return ( [optimizer.instance for optimizer in self.optimizers], [scheduler.instance for scheduler in self.schedulers] ) def _compile_model(self): for i, model in enumerate(self._pina_models): if not isinstance(model, torch.nn.ModuleDict): self._pina_models[i] = self._perform_compilation(model) @property def models(self): """ The torch model.""" return self._pina_models @property def optimizers(self): """ The torch model.""" return self._pina_optimizers @property def schedulers(self): """ The torch model.""" return self._pina_schedulers