Update solvers (#434)

* Enable DDP training with batch_size=None and add validity check for split sizes
* Refactoring SolverInterfaces (#435)
* Solver update + weighting
* Updating PINN for 0.2
* Modify GAROM + tests
* Adding more versatile loggers
* Disable compilation when running on Windows
* Fix tests

---------

Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Co-authored-by: FilippoOlivo <filippo@filippoolivo.com>
This commit is contained in:
Dario Coscia
2025-02-17 11:26:21 +01:00
committed by Nicola Demo
parent 780c4921eb
commit 9cae9a438f
50 changed files with 2848 additions and 4187 deletions

View File

@@ -2,9 +2,9 @@ import torch
from pina.model import FNO
output_channels = 5
batch_size = 15
resolution = [30, 40, 50]
lifting_dim = 128
batch_size = 4
resolution = [4, 6, 8]
lifting_dim = 24
def test_constructor():

View File

@@ -1,49 +0,0 @@
import torch
import pytest
from pina.model.network import Network
from pina.model import FeedForward
from pina import LabelTensor
data = torch.rand((20, 3))
data_lt = LabelTensor(data, ['x', 'y', 'z'])
input_dim = 3
output_dim = 4
torchmodel = FeedForward(input_dim, output_dim)
extra_feat = []
def test_constructor():
Network(model=torchmodel,
input_variables=['x', 'y', 'z'],
output_variables=['a', 'b', 'c', 'd'],
extra_features=None)
def test_forward():
net = Network(model=torchmodel,
input_variables=['x', 'y', 'z'],
output_variables=['a', 'b', 'c', 'd'],
extra_features=None)
out = net.torchmodel(data)
out_lt = net(data_lt)
assert isinstance(out, torch.Tensor)
assert isinstance(out_lt, LabelTensor)
assert out.shape == (20, 4)
assert out_lt.shape == (20, 4)
assert torch.allclose(out_lt, out)
assert out_lt.labels == ['a', 'b', 'c', 'd']
with pytest.raises(AssertionError):
net(data)
def test_backward():
net = Network(model=torchmodel,
input_variables=['x', 'y', 'z'],
output_variables=['a', 'b', 'c', 'd'],
extra_features=None)
data = torch.rand((20, 3))
data.requires_grad = True
out = net.torchmodel(data)
l = torch.mean(out)
l.backward()
assert data._grad.shape == torch.Size([20, 3])