Files
PINA/tests/test_weighting/test_scalar_weighting.py
giovanni 96402baf20 weighting refactory
Co-authored-by: Dario Coscia <dariocos99@gmail.com>
2025-09-08 14:46:33 +02:00

40 lines
1.2 KiB
Python

import pytest
import torch
from pina import Trainer
from pina.solver import PINN
from pina.model import FeedForward
from pina.loss import ScalarWeighting
from pina.problem.zoo import Poisson2DSquareProblem
# Module-level fixtures shared by every test in this file: a Poisson problem
# whose domain is discretised with 50 points, a feed-forward network sized
# from the problem's input/output variables, and the problem's condition names
# (used below to build per-condition weight dictionaries).
problem = Poisson2DSquareProblem()
problem.discretise_domain(50)
model = FeedForward(
    len(problem.input_variables),
    len(problem.output_variables),
)
condition_names = problem.conditions.keys()
@pytest.mark.parametrize(
    "weights", [1, 1.0, dict.fromkeys(condition_names, 1)]
)
def test_constructor(weights):
    """Check which ``weights`` values ``ScalarWeighting`` accepts.

    The parametrized ``weights`` (an int, a float, and a dictionary mapping
    every condition name to ``1``) must construct without error, while a
    string and a list are both expected to raise ``ValueError``.
    """
    ScalarWeighting(weights=weights)

    # Neither a string nor a list is a scalar or a dictionary, so each one
    # is expected to be rejected by the constructor.
    for invalid_weights in ("invalid", [1, 2, 3]):
        with pytest.raises(ValueError):
            ScalarWeighting(weights=invalid_weights)
@pytest.mark.parametrize(
    "weights", [1, 1.0, dict.fromkeys(condition_names, 1)]
)
def test_train_aggregation(weights):
    """Smoke-test training a ``PINN`` that uses ``ScalarWeighting``.

    Builds a solver on the shared module-level ``problem`` and ``model`` with
    the given ``weights`` and runs five epochs on the CPU; the test passes if
    training completes without raising.
    """
    solver = PINN(
        problem=problem,
        model=model,
        weighting=ScalarWeighting(weights=weights),
    )
    Trainer(solver=solver, max_epochs=5, accelerator="cpu").train()