Implement Dataset, Dataloader and DataModule class and fix SupervisedSolver
This commit is contained in:
committed by
Nicola Demo
parent
b9753c34b2
commit
c9304fb9bb
@@ -2,9 +2,7 @@
|
||||
|
||||
import torch
|
||||
from torch.nn.modules.loss import _Loss
|
||||
|
||||
|
||||
from ..optim import Optimizer, Scheduler, TorchOptimizer, TorchScheduler
|
||||
from ..optim import TorchOptimizer, TorchScheduler
|
||||
from .solver import SolverInterface
|
||||
from ..label_tensor import LabelTensor
|
||||
from ..utils import check_consistency
|
||||
@@ -39,14 +37,17 @@ class SupervisedSolver(SolverInterface):
|
||||
we are seeking to approximate multiple (discretised) functions given
|
||||
multiple (discretised) input functions.
|
||||
"""
|
||||
accepted_condition_types = ['supervised']
|
||||
__name__ = 'SupervisedSolver'
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
problem,
|
||||
model,
|
||||
loss=None,
|
||||
optimizer=None,
|
||||
scheduler=None,
|
||||
self,
|
||||
problem,
|
||||
model,
|
||||
loss=None,
|
||||
optimizer=None,
|
||||
scheduler=None,
|
||||
extra_features=None
|
||||
):
|
||||
"""
|
||||
:param AbstractProblem problem: The formualation of the problem.
|
||||
@@ -57,11 +58,8 @@ class SupervisedSolver(SolverInterface):
|
||||
features to use as augmented input.
|
||||
:param torch.optim.Optimizer optimizer: The neural network optimizer to
|
||||
use; default is :class:`torch.optim.Adam`.
|
||||
:param dict optimizer_kwargs: Optimizer constructor keyword args.
|
||||
:param float lr: The learning rate; default is 0.001.
|
||||
:param torch.optim.LRScheduler scheduler: Learning
|
||||
rate scheduler.
|
||||
:param dict scheduler_kwargs: LR scheduler constructor keyword args.
|
||||
"""
|
||||
if loss is None:
|
||||
loss = torch.nn.MSELoss()
|
||||
@@ -74,18 +72,19 @@ class SupervisedSolver(SolverInterface):
|
||||
torch.optim.lr_scheduler.ConstantLR)
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
models=model,
|
||||
problem=problem,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
optimizers=optimizer,
|
||||
schedulers=scheduler,
|
||||
extra_features=extra_features
|
||||
)
|
||||
|
||||
# check consistency
|
||||
check_consistency(loss, (LossInterface, _Loss), subclass=False)
|
||||
self._loss = loss
|
||||
self._model = self._pina_model[0]
|
||||
self._optimizer = self._pina_optimizer[0]
|
||||
self._scheduler = self._pina_scheduler[0]
|
||||
self._model = self._pina_models[0]
|
||||
self._optimizer = self._pina_optimizers[0]
|
||||
self._scheduler = self._pina_schedulers[0]
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward pass implementation for the solver.
|
||||
@@ -97,12 +96,7 @@ class SupervisedSolver(SolverInterface):
|
||||
|
||||
output = self._model(x)
|
||||
|
||||
output.labels = {
|
||||
1: {
|
||||
"name": "output",
|
||||
"dof": self.problem.output_variables
|
||||
}
|
||||
}
|
||||
output.labels = self.problem.output_variables
|
||||
return output
|
||||
|
||||
def configure_optimizers(self):
|
||||
@@ -128,16 +122,14 @@ class SupervisedSolver(SolverInterface):
|
||||
:return: The sum of the loss functions.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
|
||||
condition_idx = batch.condition
|
||||
condition_idx = batch.supervised.condition_indices
|
||||
|
||||
for condition_id in range(condition_idx.min(), condition_idx.max() + 1):
|
||||
|
||||
condition_name = self._dataloader.condition_names[condition_id]
|
||||
condition = self.problem.conditions[condition_name]
|
||||
pts = batch.input
|
||||
out = batch.output
|
||||
|
||||
pts = batch.supervised.input_points
|
||||
out = batch.supervised.output_points
|
||||
if condition_name not in self.problem.conditions:
|
||||
raise RuntimeError("Something wrong happened.")
|
||||
|
||||
@@ -167,8 +159,8 @@ class SupervisedSolver(SolverInterface):
|
||||
the network output against the true solution. This function
|
||||
should not be override if not intentionally.
|
||||
|
||||
:param LabelTensor input_tensor: The input to the neural networks.
|
||||
:param LabelTensor output_tensor: The true solution to compare the
|
||||
:param LabelTensor input_pts: The input to the neural networks.
|
||||
:param LabelTensor output_pts: The true solution to compare the
|
||||
network solution.
|
||||
:return: The residual loss averaged on the input coordinates
|
||||
:rtype: torch.Tensor
|
||||
@@ -181,7 +173,7 @@ class SupervisedSolver(SolverInterface):
|
||||
Scheduler for training.
|
||||
"""
|
||||
return self._scheduler
|
||||
|
||||
|
||||
@property
|
||||
def optimizer(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user