Improve conditions and refactor dataset classes (#475)

* Reimplement conditions

* Refactor datasets and implement LabelBatch

---------

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
Filippo Olivo
2025-03-07 11:24:09 +01:00
committed by Nicola Demo
parent bdad144461
commit a0cbf1c44a
40 changed files with 943 additions and 550 deletions

View File

@@ -9,8 +9,8 @@ from ...utils import check_consistency
from ...loss.loss_interface import LossInterface
from ...problem import InverseProblem
from ...condition import (
InputOutputPointsCondition,
InputPointsEquationCondition,
InputTargetCondition,
InputEquationCondition,
DomainEquationCondition,
)
@@ -28,8 +28,8 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
"""
accepted_conditions_types = (
InputOutputPointsCondition,
InputPointsEquationCondition,
InputTargetCondition,
InputEquationCondition,
DomainEquationCondition,
)
@@ -138,16 +138,16 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
for condition_name, points in batch:
self.__metric = condition_name
# if equations are passed
if "output_points" not in points:
input_pts = points["input_points"]
if "target" not in points:
input_pts = points["input"]
condition = self.problem.conditions[condition_name]
loss = loss_residuals(
input_pts.requires_grad_(), condition.equation
)
# if data are passed
else:
input_pts = points["input_points"]
output_pts = points["output_points"]
input_pts = points["input"]
output_pts = points["target"]
loss = self.loss_data(
input_pts=input_pts.requires_grad_(), output_pts=output_pts
)

View File

@@ -262,7 +262,7 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
for (
condition_name,
tensor,
) in self.trainer.data_module.train_dataset.input_points.items():
) in self.trainer.data_module.train_dataset.input.items():
self.weights_dict[condition_name].sa_weights.data = torch.rand(
(tensor.shape[0], 1), device=device
)