Improve conditions and refactor dataset classes (#475)

* Reimplement conditions

* Refactor datasets and implement LabelBatch

---------

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
Filippo Olivo
2025-03-07 11:24:09 +01:00
committed by Nicola Demo
parent bdad144461
commit a0cbf1c44a
40 changed files with 943 additions and 550 deletions

View File

@@ -262,7 +262,7 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
for (
condition_name,
tensor,
) in self.trainer.data_module.train_dataset.input_points.items():
) in self.trainer.data_module.train_dataset.input.items():
self.weights_dict[condition_name].sa_weights.data = torch.rand(
(tensor.shape[0], 1), device=device
)