Formatting

* Adding black as dev dependency
* Formatting pina code
* Formatting tests

committed by Nicola Demo
parent 4c4482b155
commit 42ab1a666b
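
Both rewrites black applies throughout this diff are its defaults: string quotes are normalized to double quotes, and long calls are wrapped to the 88-character line limit. As a minimal sketch (assuming black is installed, as this commit adds it as a dev dependency), its programmatic API reproduces the quote change seen below:

import black

# black.Mode() uses the defaults: double quotes, 88-character lines.
src = "self.store_log('train_loss', loss, self.get_batch_size(batch))\n"
print(black.format_str(src, mode=black.Mode()))
# -> self.store_log("train_loss", loss, self.get_batch_size(batch))
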
@@ -1,4 +1,4 @@
-""" Solver module. """
+"""Solver module."""
 
 import lightning
 import torch
@@ -18,10 +18,7 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
     SolverInterface base class. This class is a wrapper of LightningModule.
     """
 
-    def __init__(self,
-                 problem,
-                 weighting,
-                 use_lt):
+    def __init__(self, problem, weighting, use_lt):
         """
         :param problem: A problem definition instance.
         :type problem: AbstractProblem
@@ -82,7 +79,9 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
         """
         losses = self.optimization_cycle(batch)
         for name, value in losses.items():
-            self.store_log(f'{name}_loss', value.item(), self.get_batch_size(batch))
+            self.store_log(
+                f"{name}_loss", value.item(), self.get_batch_size(batch)
+            )
         loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
         return loss
 
@@ -96,7 +95,7 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
         :rtype: LabelTensor
         """
         loss = self._optimization_cycle(batch=batch)
-        self.store_log('train_loss', loss, self.get_batch_size(batch))
+        self.store_log("train_loss", loss, self.get_batch_size(batch))
         return loss
 
     def validation_step(self, batch):
@@ -107,7 +106,7 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
         :type batch: tuple
         """
         loss = self._optimization_cycle(batch=batch)
-        self.store_log('val_loss', loss, self.get_batch_size(batch))
+        self.store_log("val_loss", loss, self.get_batch_size(batch))
 
     def test_step(self, batch):
         """
@@ -117,14 +116,15 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
         :type batch: tuple
         """
         loss = self._optimization_cycle(batch=batch)
-        self.store_log('test_loss', loss, self.get_batch_size(batch))
+        self.store_log("test_loss", loss, self.get_batch_size(batch))
 
     def store_log(self, name, value, batch_size):
-        self.log(name=name,
-                 value=value,
-                 batch_size=batch_size,
-                 **self.trainer.logging_kwargs
-                 )
+        self.log(
+            name=name,
+            value=value,
+            batch_size=batch_size,
+            **self.trainer.logging_kwargs,
+        )
 
     @abstractmethod
     def forward(self, *args, **kwargs):
@@ -172,7 +172,7 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
         # assuming batch is a custom Batch object
         batch_size = 0
         for data in batch:
-            batch_size += len(data[1]['input_points'])
+            batch_size += len(data[1]["input_points"])
         return batch_size
 
     @staticmethod
@@ -203,8 +203,9 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
 
     def _check_already_compiled(self):
         models = self._pina_models
-        if len(models) == 1 and isinstance(self._pina_models[0],
-                                           torch.nn.ModuleDict):
+        if len(models) == 1 and isinstance(
+            self._pina_models[0], torch.nn.ModuleDict
+        ):
             models = list(self._pina_models.values())
         for model in models:
             if not isinstance(model, (OptimizedModule, torch.nn.ModuleDict)):
@@ -225,13 +226,15 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
 
 
 class SingleSolverInterface(SolverInterface):
-    def __init__(self,
-                 problem,
-                 model,
-                 optimizer=None,
-                 scheduler=None,
-                 weighting=None,
-                 use_lt=True):
+    def __init__(
+        self,
+        problem,
+        model,
+        optimizer=None,
+        scheduler=None,
+        weighting=None,
+        use_lt=True,
+    ):
         """
         :param problem: A problem definition instance.
         :type problem: AbstractProblem
@@ -248,9 +251,7 @@ class SingleSolverInterface(SolverInterface):
         if scheduler is None:
             scheduler = self.default_torch_scheduler()
 
-        super().__init__(problem=problem,
-                         use_lt=use_lt,
-                         weighting=weighting)
+        super().__init__(problem=problem, use_lt=use_lt, weighting=weighting)
 
         # check consistency of models argument and encapsulate in list
         check_consistency(model, torch.nn.Module)
@@ -284,10 +285,7 @@ class SingleSolverInterface(SolverInterface):
         """
         self.optimizer.hook(self.model.parameters())
         self.scheduler.hook(self.optimizer)
-        return (
-            [self.optimizer.instance],
-            [self.scheduler.instance]
-        )
+        return ([self.optimizer.instance], [self.scheduler.instance])
 
     def _compile_model(self):
         if isinstance(self._pina_models[0], torch.nn.ModuleDict):
@@ -330,13 +328,15 @@ class MultiSolverInterface(SolverInterface):
     SolverInterface class
     """
 
-    def __init__(self,
-                 problem,
-                 models,
-                 optimizers=None,
-                 schedulers=None,
-                 weighting=None,
-                 use_lt=True):
+    def __init__(
+        self,
+        problem,
+        models,
+        optimizers=None,
+        schedulers=None,
+        weighting=None,
+        use_lt=True,
+    ):
         """
         :param problem: A problem definition instance.
         :type problem: AbstractProblem
@@ -351,9 +351,9 @@ class MultiSolverInterface(SolverInterface):
         """
         if not isinstance(models, (list, tuple)) or len(models) < 2:
             raise ValueError(
-                'models should be list[torch.nn.Module] or '
-                'tuple[torch.nn.Module] with len greater than '
-                'one.'
+                "models should be list[torch.nn.Module] or "
+                "tuple[torch.nn.Module] with len greater than "
+                "one."
             )
 
         if any(opt is None for opt in optimizers):
@@ -368,9 +368,7 @@ class MultiSolverInterface(SolverInterface):
             for sched in schedulers
         ]
 
-        super().__init__(problem=problem,
-                         use_lt=use_lt,
-                         weighting=weighting)
+        super().__init__(problem=problem, use_lt=use_lt, weighting=weighting)
 
         # check consistency of models argument and encapsulate in list
         check_consistency(models, torch.nn.Module)
@@ -400,15 +398,15 @@ class MultiSolverInterface(SolverInterface):
         :return: The optimizers and the schedulers
         :rtype: tuple(list, list)
         """
-        for optimizer, scheduler, model in zip(self.optimizers,
-                                               self.schedulers,
-                                               self.models):
+        for optimizer, scheduler, model in zip(
+            self.optimizers, self.schedulers, self.models
+        ):
             optimizer.hook(model.parameters())
             scheduler.hook(optimizer)
 
         return (
             [optimizer.instance for optimizer in self.optimizers],
-            [scheduler.instance for scheduler in self.schedulers]
+            [scheduler.instance for scheduler in self.schedulers],
         )
 
     def _compile_model(self):
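
The trailing commas black adds above (e.g. use_lt=True, and **self.trainer.logging_kwargs,) enable its "magic trailing comma": once a trailing comma is present, black keeps the call exploded with one argument per line on every later run. A small sketch of that behavior, again via the programmatic API:

import black

# The call would fit on one line, but the trailing comma after `b`
# tells black to keep it exploded, one argument per line.
src = "f(a, b,)\n"
print(black.format_str(src, mode=black.Mode()))
# output:
# f(
#     a,
#     b,
# )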