Formatting

* Adding black as dev dependency
* Formatting pina code
* Formatting tests
This commit is contained in:
Dario Coscia
2025-02-24 11:26:49 +01:00
committed by Nicola Demo
parent 4c4482b155
commit 42ab1a666b
77 changed files with 1170 additions and 924 deletions

View File

@@ -1,4 +1,5 @@
""" Module for SupervisedSolver """
"""Module for SupervisedSolver"""
import torch
from torch.nn.modules.loss import _Loss
from .solver import SingleSolverInterface
@@ -38,14 +39,16 @@ class SupervisedSolver(SingleSolverInterface):
accepted_conditions_types = InputOutputPointsCondition
def __init__(self,
problem,
model,
loss=None,
optimizer=None,
scheduler=None,
weighting=None,
use_lt=True):
def __init__(
self,
problem,
model,
loss=None,
optimizer=None,
scheduler=None,
weighting=None,
use_lt=True,
):
"""
:param AbstractProblem problem: The formualation of the problem.
:param torch.nn.Module model: The neural network model to use.
@@ -61,16 +64,19 @@ class SupervisedSolver(SingleSolverInterface):
if loss is None:
loss = torch.nn.MSELoss()
super().__init__(model=model,
problem=problem,
optimizer=optimizer,
scheduler=scheduler,
weighting=weighting,
use_lt=use_lt)
super().__init__(
model=model,
problem=problem,
optimizer=optimizer,
scheduler=scheduler,
weighting=weighting,
use_lt=use_lt,
)
# check consistency
check_consistency(loss, (LossInterface, _Loss, torch.nn.Module),
subclass=False)
check_consistency(
loss, (LossInterface, _Loss, torch.nn.Module), subclass=False
)
self._loss = loss
def optimization_cycle(self, batch):
@@ -79,7 +85,7 @@ class SupervisedSolver(SingleSolverInterface):
in the given batch.
:param batch: A batch of data, where each element is a tuple containing
a condition name and a dictionary of points.
a condition name and a dictionary of points.
:type batch: list of tuples (str, dict)
:return: The computed loss for the all conditions in the batch,
cast to a subclass of `torch.Tensor`. It should return a dict
@@ -88,9 +94,13 @@ class SupervisedSolver(SingleSolverInterface):
"""
condition_loss = {}
for condition_name, points in batch:
input_pts, output_pts = points['input_points'], points['output_points']
input_pts, output_pts = (
points["input_points"],
points["output_points"],
)
condition_loss[condition_name] = self.loss_data(
input_pts=input_pts, output_pts=output_pts)
input_pts=input_pts, output_pts=output_pts
)
return condition_loss
def loss_data(self, input_pts, output_pts):
@@ -114,4 +124,4 @@ class SupervisedSolver(SingleSolverInterface):
"""
Loss for training.
"""
return self._loss
return self._loss