Update solvers (#434)

* Enable DDP training with batch_size=None and add validity check for split sizes
* Refactoring SolverInterfaces (#435)
* Solver update + weighting
* Updating PINN for 0.2
* Modify GAROM + tests
* Adding more versatile loggers
* Disable compilation when running on Windows
* Fix tests

---------

Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Co-authored-by: FilippoOlivo <filippo@filippoolivo.com>
This commit is contained in:
Dario Coscia
2025-02-17 11:26:21 +01:00
committed by Nicola Demo
parent 780c4921eb
commit 9cae9a438f
50 changed files with 2848 additions and 4187 deletions

View File

@@ -0,0 +1,42 @@
import pytest
import torch
from pina import Trainer
from pina.solvers import PINN
from pina.model import FeedForward
from pina.problem.zoo import Poisson2DSquareProblem
from pina.loss import ScalarWeighting
problem = Poisson2DSquareProblem()
model = FeedForward(len(problem.input_variables), len(problem.output_variables))
condition_names = problem.conditions.keys()
print(problem.conditions.keys())
@pytest.mark.parametrize("weights",
[1, 1., dict(zip(condition_names, [1]*len(condition_names)))])
def test_constructor(weights):
ScalarWeighting(weights=weights)
@pytest.mark.parametrize("weights", ['a', [1,2,3]])
def test_wrong_constructor(weights):
with pytest.raises(ValueError):
ScalarWeighting(weights=weights)
@pytest.mark.parametrize("weights",
[1, 1., dict(zip(condition_names, [1]*len(condition_names)))])
def test_aggregate(weights):
weighting = ScalarWeighting(weights=weights)
losses = dict(zip(condition_names, [torch.randn(1) for _ in range(len(condition_names))]))
weighting.aggregate(losses=losses)
@pytest.mark.parametrize("weights",
[1, 1., dict(zip(condition_names, [1]*len(condition_names)))])
def test_train_aggregation(weights):
weighting = ScalarWeighting(weights=weights)
problem.discretise_domain(50)
solver = PINN(
problem=problem,
model=model,
weighting=weighting)
trainer = Trainer(solver=solver, max_epochs=5, accelerator='cpu')
trainer.train()