Formatting
* Adding black as dev dependency * Formatting pina code * Formatting tests
This commit is contained in:
committed by
Nicola Demo
parent
4c4482b155
commit
42ab1a666b
@@ -1,4 +1,5 @@
|
||||
""" Module for SupervisedSolver """
|
||||
"""Module for SupervisedSolver"""
|
||||
|
||||
import torch
|
||||
from torch.nn.modules.loss import _Loss
|
||||
from .solver import SingleSolverInterface
|
||||
@@ -38,14 +39,16 @@ class SupervisedSolver(SingleSolverInterface):
|
||||
|
||||
accepted_conditions_types = InputOutputPointsCondition
|
||||
|
||||
def __init__(self,
|
||||
problem,
|
||||
model,
|
||||
loss=None,
|
||||
optimizer=None,
|
||||
scheduler=None,
|
||||
weighting=None,
|
||||
use_lt=True):
|
||||
def __init__(
|
||||
self,
|
||||
problem,
|
||||
model,
|
||||
loss=None,
|
||||
optimizer=None,
|
||||
scheduler=None,
|
||||
weighting=None,
|
||||
use_lt=True,
|
||||
):
|
||||
"""
|
||||
:param AbstractProblem problem: The formualation of the problem.
|
||||
:param torch.nn.Module model: The neural network model to use.
|
||||
@@ -61,16 +64,19 @@ class SupervisedSolver(SingleSolverInterface):
|
||||
if loss is None:
|
||||
loss = torch.nn.MSELoss()
|
||||
|
||||
super().__init__(model=model,
|
||||
problem=problem,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
weighting=weighting,
|
||||
use_lt=use_lt)
|
||||
super().__init__(
|
||||
model=model,
|
||||
problem=problem,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
weighting=weighting,
|
||||
use_lt=use_lt,
|
||||
)
|
||||
|
||||
# check consistency
|
||||
check_consistency(loss, (LossInterface, _Loss, torch.nn.Module),
|
||||
subclass=False)
|
||||
check_consistency(
|
||||
loss, (LossInterface, _Loss, torch.nn.Module), subclass=False
|
||||
)
|
||||
self._loss = loss
|
||||
|
||||
def optimization_cycle(self, batch):
|
||||
@@ -79,7 +85,7 @@ class SupervisedSolver(SingleSolverInterface):
|
||||
in the given batch.
|
||||
|
||||
:param batch: A batch of data, where each element is a tuple containing
|
||||
a condition name and a dictionary of points.
|
||||
a condition name and a dictionary of points.
|
||||
:type batch: list of tuples (str, dict)
|
||||
:return: The computed loss for the all conditions in the batch,
|
||||
cast to a subclass of `torch.Tensor`. It should return a dict
|
||||
@@ -88,9 +94,13 @@ class SupervisedSolver(SingleSolverInterface):
|
||||
"""
|
||||
condition_loss = {}
|
||||
for condition_name, points in batch:
|
||||
input_pts, output_pts = points['input_points'], points['output_points']
|
||||
input_pts, output_pts = (
|
||||
points["input_points"],
|
||||
points["output_points"],
|
||||
)
|
||||
condition_loss[condition_name] = self.loss_data(
|
||||
input_pts=input_pts, output_pts=output_pts)
|
||||
input_pts=input_pts, output_pts=output_pts
|
||||
)
|
||||
return condition_loss
|
||||
|
||||
def loss_data(self, input_pts, output_pts):
|
||||
@@ -114,4 +124,4 @@ class SupervisedSolver(SingleSolverInterface):
|
||||
"""
|
||||
Loss for training.
|
||||
"""
|
||||
return self._loss
|
||||
return self._loss
|
||||
|
||||
Reference in New Issue
Block a user