supervised working

This commit is contained in:
Nicola Demo
2024-08-08 16:19:52 +02:00
parent 5245a0b68c
commit 9d9c2aa23e
61 changed files with 375 additions and 262 deletions

View File

@@ -1,35 +0,0 @@
from . import ConditionInterface
class InputOutputCondition(ConditionInterface):
"""
Condition for input/output data.
"""
__slots__ = ["input_points", "output_points"]
def __init__(self, input_points, output_points):
"""
Constructor for the `InputOutputCondition` class.
"""
super().__init__()
self.input_points = input_points
self.output_points = output_points
def residual(self, model):
"""
Compute the residual of the condition.
"""
return self.batch_residual(model, self.input_points, self.output_points)
@staticmethod
def batch_residual(model, input_points, output_points):
"""
Compute the residual of the condition for a single batch. Input and
output points are provided as arguments.
:param torch.nn.Module model: The model to evaluate the condition.
:param torch.Tensor input_points: The input points.
:param torch.Tensor output_points: The output points.
"""
return output_points - model(input_points)