Documentation for v0.1 version (#199)

* Adding Equations, solving typos
* improve _code.rst
* the team rst and restuctore index.rst
* fixing errors

---------

Co-authored-by: Dario Coscia <dariocoscia@dhcp-015.eduroam.sissa.it>
This commit is contained in:
Dario Coscia
2023-11-08 14:39:00 +01:00
committed by Nicola Demo
parent 3f9305d475
commit 8b7b61b3bd
144 changed files with 2741 additions and 1766 deletions

View File

@@ -1,8 +1,9 @@
""" Module """
""" Module for SystemEquation. """
import torch
from .equation import Equation
from ..utils import check_consistency
class SystemEquation(Equation):
def __init__(self, list_equation, reduction='mean'):
@@ -14,7 +15,7 @@ class SystemEquation(Equation):
A ``SystemEquation`` is specified by a list of
equations.
:param callable equation: A ``torch`` callable equation to
:param Callable equation: A ``torch`` callable equation to
evaluate the residual
:param str reduction: Specifies the reduction to apply to the output:
``none`` | ``mean`` | ``sum`` | ``callable``. ``none``: no reduction
@@ -28,7 +29,7 @@ class SystemEquation(Equation):
# equations definition
self.equations = []
for _, equation in enumerate(list_equation):
for _, equation in enumerate(list_equation):
self.equations.append(Equation(equation))
# possible reduction
@@ -39,7 +40,8 @@ class SystemEquation(Equation):
elif (reduction == 'none') or callable(reduction):
self.reduction = reduction
else:
raise NotImplementedError('Only mean and sum reductions implemented.')
raise NotImplementedError(
'Only mean and sum reductions implemented.')
def residual(self, input_, output_):
"""
@@ -52,12 +54,10 @@ class SystemEquation(Equation):
aggregated by the ``reduction`` defined in the ``__init__``.
:rtype: LabelTensor
"""
residual = torch.hstack([
equation.residual(input_, output_)
for equation in self.equations
])
residual = torch.hstack(
[equation.residual(input_, output_) for equation in self.equations])
if self.reduction == 'none':
return residual
return self.reduction(residual, dim=-1)
return self.reduction(residual, dim=-1)