From 4bec5bfc9a0c8cbdba298c3f338adba48efb1881 Mon Sep 17 00:00:00 2001 From: Giovanni Canali <115086358+gc031298@users.noreply.github.com> Date: Wed, 22 Jan 2025 10:48:42 +0100 Subject: [PATCH] Update import from lightning.pytorch (#409) * update import * Remove unnecessary import return type --------- Co-authored-by: Filippo Olivo --- pina/callbacks/adaptive_refinment_callbacks.py | 2 +- pina/callbacks/optimizer_callbacks.py | 2 +- pina/callbacks/processing_callbacks.py | 6 +++--- pina/solvers/supervised.py | 3 +-- 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/pina/callbacks/adaptive_refinment_callbacks.py b/pina/callbacks/adaptive_refinment_callbacks.py index 5af2cc8..00d6dfb 100644 --- a/pina/callbacks/adaptive_refinment_callbacks.py +++ b/pina/callbacks/adaptive_refinment_callbacks.py @@ -1,7 +1,7 @@ """PINA Callbacks Implementations""" import torch -from pytorch_lightning.callbacks import Callback +from lightning.pytorch.callbacks import Callback from ..label_tensor import LabelTensor from ..utils import check_consistency diff --git a/pina/callbacks/optimizer_callbacks.py b/pina/callbacks/optimizer_callbacks.py index c11db88..6905ebf 100644 --- a/pina/callbacks/optimizer_callbacks.py +++ b/pina/callbacks/optimizer_callbacks.py @@ -1,6 +1,6 @@ """PINA Callbacks Implementations""" -from pytorch_lightning.callbacks import Callback +from lightning.pytorch.callbacks import Callback import torch from ..utils import check_consistency diff --git a/pina/callbacks/processing_callbacks.py b/pina/callbacks/processing_callbacks.py index a70218e..e65506c 100644 --- a/pina/callbacks/processing_callbacks.py +++ b/pina/callbacks/processing_callbacks.py @@ -1,11 +1,11 @@ """PINA Callbacks Implementations""" -from pytorch_lightning.core.module import LightningModule -from pytorch_lightning.trainer.trainer import Trainer +from lightning.pytorch.core.module import LightningModule +from lightning.pytorch.trainer.trainer import Trainer import torch import copy -from pytorch_lightning.callbacks import Callback, TQDMProgressBar +from lightning.pytorch.callbacks import Callback, TQDMProgressBar from lightning.pytorch.callbacks.progress.progress_bar import ( get_standard_metrics, ) diff --git a/pina/solvers/supervised.py b/pina/solvers/supervised.py index c7f5f66..99fec09 100644 --- a/pina/solvers/supervised.py +++ b/pina/solvers/supervised.py @@ -1,6 +1,5 @@ """ Module for SupervisedSolver """ import torch -from pytorch_lightning.utilities.types import STEP_OUTPUT from torch.nn.modules.loss import _Loss from ..optim import TorchOptimizer, TorchScheduler from .solver import SolverInterface @@ -145,7 +144,7 @@ class SupervisedSolver(SolverInterface): self.log('val_loss', loss, prog_bar=True, logger=True, batch_size=self.get_batch_size(batch), sync_dist=True) - def test_step(self, batch, batch_idx) -> STEP_OUTPUT: + def test_step(self, batch, batch_idx): """ Solver test step. """