Update import from lightning.pytorch (#409)

* update import
* Remove unnecessary import return type

---------

Co-authored-by: Filippo Olivo <filippo@filippoolivo.com>
This commit is contained in:
Giovanni Canali
2025-01-22 10:48:42 +01:00
committed by Nicola Demo
parent d51de028bd
commit 4bec5bfc9a
4 changed files with 6 additions and 7 deletions

View File

@@ -1,7 +1,7 @@
"""PINA Callbacks Implementations""" """PINA Callbacks Implementations"""
import torch import torch
from pytorch_lightning.callbacks import Callback from lightning.pytorch.callbacks import Callback
from ..label_tensor import LabelTensor from ..label_tensor import LabelTensor
from ..utils import check_consistency from ..utils import check_consistency

View File

@@ -1,6 +1,6 @@
"""PINA Callbacks Implementations""" """PINA Callbacks Implementations"""
from pytorch_lightning.callbacks import Callback from lightning.pytorch.callbacks import Callback
import torch import torch
from ..utils import check_consistency from ..utils import check_consistency

View File

@@ -1,11 +1,11 @@
"""PINA Callbacks Implementations""" """PINA Callbacks Implementations"""
from pytorch_lightning.core.module import LightningModule from lightning.pytorch.core.module import LightningModule
from pytorch_lightning.trainer.trainer import Trainer from lightning.pytorch.trainer.trainer import Trainer
import torch import torch
import copy import copy
from pytorch_lightning.callbacks import Callback, TQDMProgressBar from lightning.pytorch.callbacks import Callback, TQDMProgressBar
from lightning.pytorch.callbacks.progress.progress_bar import ( from lightning.pytorch.callbacks.progress.progress_bar import (
get_standard_metrics, get_standard_metrics,
) )

View File

@@ -1,6 +1,5 @@
""" Module for SupervisedSolver """ """ Module for SupervisedSolver """
import torch import torch
from pytorch_lightning.utilities.types import STEP_OUTPUT
from torch.nn.modules.loss import _Loss from torch.nn.modules.loss import _Loss
from ..optim import TorchOptimizer, TorchScheduler from ..optim import TorchOptimizer, TorchScheduler
from .solver import SolverInterface from .solver import SolverInterface
@@ -145,7 +144,7 @@ class SupervisedSolver(SolverInterface):
self.log('val_loss', loss, prog_bar=True, logger=True, self.log('val_loss', loss, prog_bar=True, logger=True,
batch_size=self.get_batch_size(batch), sync_dist=True) batch_size=self.get_batch_size(batch), sync_dist=True)
def test_step(self, batch, batch_idx) -> STEP_OUTPUT: def test_step(self, batch, batch_idx):
""" """
Solver test step. Solver test step.
""" """