Formatting

* Adding black as dev dependency
* Formatting pina code
* Formatting tests
This commit is contained in:
Dario Coscia
2025-02-24 11:26:49 +01:00
committed by Nicola Demo
parent 4c4482b155
commit 42ab1a666b
77 changed files with 1170 additions and 924 deletions

View File

@@ -1,4 +1,4 @@
""" Solver module. """
"""Solver module."""
import lightning
import torch
@@ -18,10 +18,7 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
SolverInterface base class. This class is a wrapper of LightningModule.
"""
def __init__(self,
problem,
weighting,
use_lt):
def __init__(self, problem, weighting, use_lt):
"""
:param problem: A problem definition instance.
:type problem: AbstractProblem
@@ -82,7 +79,9 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
"""
losses = self.optimization_cycle(batch)
for name, value in losses.items():
self.store_log(f'{name}_loss', value.item(), self.get_batch_size(batch))
self.store_log(
f"{name}_loss", value.item(), self.get_batch_size(batch)
)
loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
return loss
@@ -96,7 +95,7 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
:rtype: LabelTensor
"""
loss = self._optimization_cycle(batch=batch)
self.store_log('train_loss', loss, self.get_batch_size(batch))
self.store_log("train_loss", loss, self.get_batch_size(batch))
return loss
def validation_step(self, batch):
@@ -107,7 +106,7 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
:type batch: tuple
"""
loss = self._optimization_cycle(batch=batch)
self.store_log('val_loss', loss, self.get_batch_size(batch))
self.store_log("val_loss", loss, self.get_batch_size(batch))
def test_step(self, batch):
"""
@@ -117,14 +116,15 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
:type batch: tuple
"""
loss = self._optimization_cycle(batch=batch)
self.store_log('test_loss', loss, self.get_batch_size(batch))
self.store_log("test_loss", loss, self.get_batch_size(batch))
def store_log(self, name, value, batch_size):
self.log(name=name,
value=value,
batch_size=batch_size,
**self.trainer.logging_kwargs
)
self.log(
name=name,
value=value,
batch_size=batch_size,
**self.trainer.logging_kwargs,
)
@abstractmethod
def forward(self, *args, **kwargs):
@@ -172,7 +172,7 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
# assuming batch is a custom Batch object
batch_size = 0
for data in batch:
batch_size += len(data[1]['input_points'])
batch_size += len(data[1]["input_points"])
return batch_size
@staticmethod
@@ -203,8 +203,9 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
def _check_already_compiled(self):
models = self._pina_models
if len(models) == 1 and isinstance(self._pina_models[0],
torch.nn.ModuleDict):
if len(models) == 1 and isinstance(
self._pina_models[0], torch.nn.ModuleDict
):
models = list(self._pina_models.values())
for model in models:
if not isinstance(model, (OptimizedModule, torch.nn.ModuleDict)):
@@ -225,13 +226,15 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
class SingleSolverInterface(SolverInterface):
def __init__(self,
problem,
model,
optimizer=None,
scheduler=None,
weighting=None,
use_lt=True):
def __init__(
self,
problem,
model,
optimizer=None,
scheduler=None,
weighting=None,
use_lt=True,
):
"""
:param problem: A problem definition instance.
:type problem: AbstractProblem
@@ -248,9 +251,7 @@ class SingleSolverInterface(SolverInterface):
if scheduler is None:
scheduler = self.default_torch_scheduler()
super().__init__(problem=problem,
use_lt=use_lt,
weighting=weighting)
super().__init__(problem=problem, use_lt=use_lt, weighting=weighting)
# check consistency of models argument and encapsulate in list
check_consistency(model, torch.nn.Module)
@@ -284,10 +285,7 @@ class SingleSolverInterface(SolverInterface):
"""
self.optimizer.hook(self.model.parameters())
self.scheduler.hook(self.optimizer)
return (
[self.optimizer.instance],
[self.scheduler.instance]
)
return ([self.optimizer.instance], [self.scheduler.instance])
def _compile_model(self):
if isinstance(self._pina_models[0], torch.nn.ModuleDict):
@@ -330,13 +328,15 @@ class MultiSolverInterface(SolverInterface):
SolverInterface class
"""
def __init__(self,
problem,
models,
optimizers=None,
schedulers=None,
weighting=None,
use_lt=True):
def __init__(
self,
problem,
models,
optimizers=None,
schedulers=None,
weighting=None,
use_lt=True,
):
"""
:param problem: A problem definition instance.
:type problem: AbstractProblem
@@ -351,9 +351,9 @@ class MultiSolverInterface(SolverInterface):
"""
if not isinstance(models, (list, tuple)) or len(models) < 2:
raise ValueError(
'models should be list[torch.nn.Module] or '
'tuple[torch.nn.Module] with len greater than '
'one.'
"models should be list[torch.nn.Module] or "
"tuple[torch.nn.Module] with len greater than "
"one."
)
if any(opt is None for opt in optimizers):
@@ -368,9 +368,7 @@ class MultiSolverInterface(SolverInterface):
for sched in schedulers
]
super().__init__(problem=problem,
use_lt=use_lt,
weighting=weighting)
super().__init__(problem=problem, use_lt=use_lt, weighting=weighting)
# check consistency of models argument and encapsulate in list
check_consistency(models, torch.nn.Module)
@@ -400,15 +398,15 @@ class MultiSolverInterface(SolverInterface):
:return: The optimizers and the schedulers
:rtype: tuple(list, list)
"""
for optimizer, scheduler, model in zip(self.optimizers,
self.schedulers,
self.models):
for optimizer, scheduler, model in zip(
self.optimizers, self.schedulers, self.models
):
optimizer.hook(model.parameters())
scheduler.hook(optimizer)
return (
[optimizer.instance for optimizer in self.optimizers],
[scheduler.instance for scheduler in self.schedulers]
[scheduler.instance for scheduler in self.schedulers],
)
def _compile_model(self):