diff --git a/pina/callbacks/adaptive_refinment_callbacks.py b/pina/callbacks/adaptive_refinment_callbacks.py index 5af2cc8..00d6dfb 100644 --- a/pina/callbacks/adaptive_refinment_callbacks.py +++ b/pina/callbacks/adaptive_refinment_callbacks.py @@ -1,7 +1,7 @@ """PINA Callbacks Implementations""" import torch -from pytorch_lightning.callbacks import Callback +from lightning.pytorch.callbacks import Callback from ..label_tensor import LabelTensor from ..utils import check_consistency diff --git a/pina/callbacks/optimizer_callbacks.py b/pina/callbacks/optimizer_callbacks.py index c11db88..6905ebf 100644 --- a/pina/callbacks/optimizer_callbacks.py +++ b/pina/callbacks/optimizer_callbacks.py @@ -1,6 +1,6 @@ """PINA Callbacks Implementations""" -from pytorch_lightning.callbacks import Callback +from lightning.pytorch.callbacks import Callback import torch from ..utils import check_consistency diff --git a/pina/callbacks/processing_callbacks.py b/pina/callbacks/processing_callbacks.py index a70218e..e65506c 100644 --- a/pina/callbacks/processing_callbacks.py +++ b/pina/callbacks/processing_callbacks.py @@ -1,11 +1,11 @@ """PINA Callbacks Implementations""" -from pytorch_lightning.core.module import LightningModule -from pytorch_lightning.trainer.trainer import Trainer +from lightning.pytorch.core.module import LightningModule +from lightning.pytorch.trainer.trainer import Trainer import torch import copy -from pytorch_lightning.callbacks import Callback, TQDMProgressBar +from lightning.pytorch.callbacks import Callback, TQDMProgressBar from lightning.pytorch.callbacks.progress.progress_bar import ( get_standard_metrics, ) diff --git a/pina/solvers/supervised.py b/pina/solvers/supervised.py index c7f5f66..99fec09 100644 --- a/pina/solvers/supervised.py +++ b/pina/solvers/supervised.py @@ -1,6 +1,5 @@ """ Module for SupervisedSolver """ import torch -from pytorch_lightning.utilities.types import STEP_OUTPUT from torch.nn.modules.loss import _Loss from ..optim import TorchOptimizer, TorchScheduler from .solver import SolverInterface @@ -145,7 +144,7 @@ class SupervisedSolver(SolverInterface): self.log('val_loss', loss, prog_bar=True, logger=True, batch_size=self.get_batch_size(batch), sync_dist=True) - def test_step(self, batch, batch_idx) -> STEP_OUTPUT: + def test_step(self, batch, batch_idx): """ Solver test step. """