Implementation of DataLoader and DataModule (#383)
Refactoring for 0.2 * Data module, data loader and dataset * Refactor LabelTensor * Refactor solvers Co-authored-by: dario-coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
dd43c8304c
commit
a27bd35443
@@ -2,7 +2,7 @@
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from ..model.network import Network
|
||||
import pytorch_lightning
|
||||
import lightning
|
||||
from ..utils import check_consistency
|
||||
from ..problem import AbstractProblem
|
||||
from ..optim import Optimizer, Scheduler
|
||||
@@ -10,7 +10,8 @@ import torch
|
||||
import sys
|
||||
|
||||
|
||||
class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
|
||||
|
||||
class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
"""
|
||||
Solver base class. This class inherits is a wrapper of
|
||||
LightningModule class, inheriting all the
|
||||
@@ -83,7 +84,6 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
|
||||
" optimizers.")
|
||||
|
||||
# extra features handling
|
||||
|
||||
self._pina_models = models
|
||||
self._pina_optimizers = optimizers
|
||||
self._pina_schedulers = schedulers
|
||||
@@ -94,7 +94,7 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def training_step(self, batch, batch_idx):
|
||||
def training_step(self, batch):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -138,8 +138,16 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
|
||||
TODO
|
||||
"""
|
||||
for _, condition in problem.conditions.items():
|
||||
if not set(self.accepted_condition_types).issubset(
|
||||
condition.condition_type):
|
||||
if not set(condition.condition_type).issubset(
|
||||
set(self.accepted_condition_types)):
|
||||
raise ValueError(
|
||||
f'{self.__name__} support only dose not support condition '
|
||||
f'{self.__name__} dose not support condition '
|
||||
f'{condition.condition_type}')
|
||||
|
||||
@staticmethod
|
||||
def get_batch_size(batch):
|
||||
# Assuming batch is your custom Batch object
|
||||
batch_size = 0
|
||||
for data in batch:
|
||||
batch_size += len(data[1]['input_points'])
|
||||
return batch_size
|
||||
Reference in New Issue
Block a user