🎨 Format Python code with psf/black
This commit is contained in:
@@ -1,10 +1,14 @@
|
||||
""" Module for SupervisedSolver """
|
||||
|
||||
import torch
|
||||
import sys
|
||||
|
||||
try:
|
||||
from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0
|
||||
except ImportError:
|
||||
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler # torch < 2.0
|
||||
from torch.optim.lr_scheduler import (
|
||||
_LRScheduler as LRScheduler,
|
||||
) # torch < 2.0
|
||||
|
||||
from torch.optim.lr_scheduler import ConstantLR
|
||||
|
||||
@@ -18,7 +22,7 @@ from torch.nn.modules.loss import _Loss
|
||||
class SupervisedSolver(SolverInterface):
|
||||
"""
|
||||
SupervisedSolver solver class. This class implements a SupervisedSolver,
|
||||
using a user specified ``model`` to solve a specific ``problem``.
|
||||
using a user specified ``model`` to solve a specific ``problem``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -28,14 +32,11 @@ class SupervisedSolver(SolverInterface):
|
||||
extra_features=None,
|
||||
loss=torch.nn.MSELoss(),
|
||||
optimizer=torch.optim.Adam,
|
||||
optimizer_kwargs={'lr': 0.001},
|
||||
optimizer_kwargs={"lr": 0.001},
|
||||
scheduler=ConstantLR,
|
||||
scheduler_kwargs={
|
||||
"factor": 1,
|
||||
"total_iters": 0
|
||||
},
|
||||
scheduler_kwargs={"factor": 1, "total_iters": 0},
|
||||
):
|
||||
'''
|
||||
"""
|
||||
:param AbstractProblem problem: The formualation of the problem.
|
||||
:param torch.nn.Module model: The neural network model to use.
|
||||
:param torch.nn.Module loss: The loss function used as minimizer,
|
||||
@@ -49,12 +50,14 @@ class SupervisedSolver(SolverInterface):
|
||||
:param torch.optim.LRScheduler scheduler: Learning
|
||||
rate scheduler.
|
||||
:param dict scheduler_kwargs: LR scheduler constructor keyword args.
|
||||
'''
|
||||
super().__init__(models=[model],
|
||||
problem=problem,
|
||||
optimizers=[optimizer],
|
||||
optimizers_kwargs=[optimizer_kwargs],
|
||||
extra_features=extra_features)
|
||||
"""
|
||||
super().__init__(
|
||||
models=[model],
|
||||
problem=problem,
|
||||
optimizers=[optimizer],
|
||||
optimizers_kwargs=[optimizer_kwargs],
|
||||
extra_features=extra_features,
|
||||
)
|
||||
|
||||
# check consistency
|
||||
check_consistency(scheduler, LRScheduler, subclass=True)
|
||||
@@ -69,7 +72,7 @@ class SupervisedSolver(SolverInterface):
|
||||
def forward(self, x):
|
||||
"""Forward pass implementation for the solver.
|
||||
|
||||
:param torch.Tensor x: Input tensor.
|
||||
:param torch.Tensor x: Input tensor.
|
||||
:return: Solver solution.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
@@ -95,32 +98,39 @@ class SupervisedSolver(SolverInterface):
|
||||
"""
|
||||
|
||||
dataloader = self.trainer.train_dataloader
|
||||
condition_idx = batch['condition']
|
||||
condition_idx = batch["condition"]
|
||||
|
||||
for condition_id in range(condition_idx.min(), condition_idx.max()+1):
|
||||
for condition_id in range(condition_idx.min(), condition_idx.max() + 1):
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
condition_name = dataloader.condition_names[condition_id]
|
||||
else:
|
||||
condition_name = dataloader.loaders.condition_names[condition_id]
|
||||
condition_name = dataloader.loaders.condition_names[
|
||||
condition_id
|
||||
]
|
||||
condition = self.problem.conditions[condition_name]
|
||||
pts = batch['pts']
|
||||
out = batch['output']
|
||||
pts = batch["pts"]
|
||||
out = batch["output"]
|
||||
|
||||
if condition_name not in self.problem.conditions:
|
||||
raise RuntimeError('Something wrong happened.')
|
||||
raise RuntimeError("Something wrong happened.")
|
||||
|
||||
# for data driven mode
|
||||
if not hasattr(condition, 'output_points'):
|
||||
raise NotImplementedError('Supervised solver works only in data-driven mode.')
|
||||
|
||||
if not hasattr(condition, "output_points"):
|
||||
raise NotImplementedError(
|
||||
"Supervised solver works only in data-driven mode."
|
||||
)
|
||||
|
||||
output_pts = out[condition_idx == condition_id]
|
||||
input_pts = pts[condition_idx == condition_id]
|
||||
|
||||
loss = self.loss(self.forward(input_pts), output_pts) * condition.data_weight
|
||||
loss = (
|
||||
self.loss(self.forward(input_pts), output_pts)
|
||||
* condition.data_weight
|
||||
)
|
||||
loss = loss.as_subclass(torch.Tensor)
|
||||
|
||||
self.log('mean_loss', float(loss), prog_bar=True, logger=True)
|
||||
self.log("mean_loss", float(loss), prog_bar=True, logger=True)
|
||||
return loss
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user