"""Tests for :class:`pina.solvers.ReducedOrderModelSolver`.

The tests cover construction (including rejection of reduction networks that
lack ``encode`` or ``decode``), CPU training, resuming training from a
checkpoint, and reloading a solver from a checkpoint.
"""

import contextlib
import shutil

import torch
import pytest

from pina.problem import AbstractProblem
from pina import Condition, LabelTensor
from pina.solvers import ReducedOrderModelSolver
from pina.trainer import Trainer
from pina.model import FeedForward
from pina.loss import LpLoss


class NeuralOperatorProblem(AbstractProblem):
    """Toy data-driven problem mapping 2 inputs to 100 outputs.

    A single ``'data'`` condition holds 10 random input/output samples.
    """

    input_variables = ['u_0', 'u_1']
    output_variables = [f'u_{i}' for i in range(100)]
    conditions = {
        'data': Condition(
            input_points=LabelTensor(torch.rand(10, 2), input_variables),
            output_points=LabelTensor(torch.rand(10, 100), output_variables),
        )
    }


class AE(torch.nn.Module):
    """Minimal autoencoder exposing both ``encode`` and ``decode`` sub-networks.

    Args:
        input_dimensions: size of the full-order field.
        rank: size of the latent (reduced) representation.
    """

    def __init__(self, input_dimensions, rank):
        super().__init__()
        self.encode = FeedForward(input_dimensions, rank,
                                  layers=[input_dimensions // 4])
        self.decode = FeedForward(rank, input_dimensions,
                                  layers=[input_dimensions // 4])


class AE_missing_encode(torch.nn.Module):
    """Invalid reduction network used to test solver validation.

    NOTE: despite its name, this network defines only ``encode``; it is the
    ``decode`` attribute that is missing.
    """

    def __init__(self, input_dimensions, rank):
        super().__init__()
        self.encode = FeedForward(input_dimensions, rank,
                                  layers=[input_dimensions // 4])


class AE_missing_decode(torch.nn.Module):
    """Invalid reduction network used to test solver validation.

    NOTE: despite its name, this network defines only ``decode``; it is the
    ``encode`` attribute that is missing.
    """

    def __init__(self, input_dimensions, rank):
        super().__init__()
        self.decode = FeedForward(rank, input_dimensions,
                                  layers=[input_dimensions // 4])


# Shared fixtures: latent size, problem instance, and the two networks the
# solver is built from.
rank = 10
problem = NeuralOperatorProblem()
interpolation_net = FeedForward(len(problem.input_variables), rank)
reduction_net = AE(len(problem.output_variables), rank)


@contextlib.contextmanager
def _scratch_dir(path):
    """Yield ``path`` as an empty scratch directory and always remove it afterwards.

    Any leftover from an earlier aborted run is deleted first. Otherwise
    Lightning would log to ``version_1`` and the hard-coded ``version_0``
    checkpoint paths below would not exist.
    """
    shutil.rmtree(path, ignore_errors=True)
    try:
        yield path
    finally:
        shutil.rmtree(path, ignore_errors=True)


def test_constructor():
    """A valid reduction network is accepted; incomplete ones raise SyntaxError."""
    ReducedOrderModelSolver(problem=problem,
                            reduction_network=reduction_net,
                            interpolation_network=interpolation_net)
    # Each invalid case needs its own ``pytest.raises`` block: the first
    # exception leaves the ``with`` body, so a second call in the same block
    # would never run.
    with pytest.raises(SyntaxError):
        ReducedOrderModelSolver(
            problem=problem,
            reduction_network=AE_missing_encode(
                len(problem.output_variables), rank),
            interpolation_network=interpolation_net)
    with pytest.raises(SyntaxError):
        ReducedOrderModelSolver(
            problem=problem,
            reduction_network=AE_missing_decode(
                len(problem.output_variables), rank),
            interpolation_network=interpolation_net)


def test_train_cpu():
    """Training on CPU for a few epochs completes without error."""
    solver = ReducedOrderModelSolver(problem=problem,
                                     reduction_network=reduction_net,
                                     interpolation_network=interpolation_net,
                                     loss=LpLoss())
    trainer = Trainer(solver=solver, max_epochs=3, accelerator='cpu',
                      batch_size=20)
    trainer.train()


def test_train_restore():
    """Training can resume from a checkpoint written by an earlier run."""
    tmpdir = "tests/tmp_restore"
    solver = ReducedOrderModelSolver(problem=problem,
                                     reduction_network=reduction_net,
                                     interpolation_network=interpolation_net,
                                     loss=LpLoss())
    with _scratch_dir(tmpdir):
        trainer = Trainer(solver=solver, max_epochs=5, accelerator='cpu',
                          default_root_dir=tmpdir)
        trainer.train()
        ntrainer = Trainer(solver=solver, max_epochs=15, accelerator='cpu')
        # The checkpoint name assumes one optimisation step per epoch.
        ntrainer.train(
            ckpt_path=f'{tmpdir}/lightning_logs/version_0/'
                      'checkpoints/epoch=4-step=5.ckpt')


def test_train_load():
    """A solver reloaded from a checkpoint reproduces the trained solver's output."""
    tmpdir = "tests/tmp_load"
    solver = ReducedOrderModelSolver(problem=problem,
                                     reduction_network=reduction_net,
                                     interpolation_network=interpolation_net,
                                     loss=LpLoss())
    with _scratch_dir(tmpdir):
        trainer = Trainer(solver=solver, max_epochs=15, accelerator='cpu',
                          default_root_dir=tmpdir)
        trainer.train()
        new_solver = ReducedOrderModelSolver.load_from_checkpoint(
            f'{tmpdir}/lightning_logs/version_0/checkpoints/'
            'epoch=14-step=15.ckpt',
            problem=problem,
            reduction_network=reduction_net,
            interpolation_network=interpolation_net)
        test_pts = LabelTensor(torch.rand(20, 2), problem.input_variables)
        assert new_solver.forward(test_pts).shape == (20, 100)
        assert (new_solver.forward(test_pts).shape
                == solver.forward(test_pts).shape)
        torch.testing.assert_close(new_solver.forward(test_pts),
                                   solver.forward(test_pts))