Improve conditions and refactor dataset classes (#475)
* Reimplement conditions * Refactor datasets and implement LabelBatch --------- Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
bdad144461
commit
a0cbf1c44a
@@ -9,8 +9,8 @@ from ...utils import check_consistency
|
||||
from ...loss.loss_interface import LossInterface
|
||||
from ...problem import InverseProblem
|
||||
from ...condition import (
|
||||
InputOutputPointsCondition,
|
||||
InputPointsEquationCondition,
|
||||
InputTargetCondition,
|
||||
InputEquationCondition,
|
||||
DomainEquationCondition,
|
||||
)
|
||||
|
||||
@@ -28,8 +28,8 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
"""
|
||||
|
||||
accepted_conditions_types = (
|
||||
InputOutputPointsCondition,
|
||||
InputPointsEquationCondition,
|
||||
InputTargetCondition,
|
||||
InputEquationCondition,
|
||||
DomainEquationCondition,
|
||||
)
|
||||
|
||||
@@ -138,16 +138,16 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
for condition_name, points in batch:
|
||||
self.__metric = condition_name
|
||||
# if equations are passed
|
||||
if "output_points" not in points:
|
||||
input_pts = points["input_points"]
|
||||
if "target" not in points:
|
||||
input_pts = points["input"]
|
||||
condition = self.problem.conditions[condition_name]
|
||||
loss = loss_residuals(
|
||||
input_pts.requires_grad_(), condition.equation
|
||||
)
|
||||
# if data are passed
|
||||
else:
|
||||
input_pts = points["input_points"]
|
||||
output_pts = points["output_points"]
|
||||
input_pts = points["input"]
|
||||
output_pts = points["target"]
|
||||
loss = self.loss_data(
|
||||
input_pts=input_pts.requires_grad_(), output_pts=output_pts
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user