minor fix

This commit is contained in:
Dario Coscia
2023-09-19 15:13:50 +02:00
committed by Nicola Demo
parent 4d1187898f
commit 1936133ad5
5 changed files with 222 additions and 57 deletions

View File

@@ -0,0 +1,81 @@
from pina.callbacks import SwitchOptimizer
import torch
import pytest
from pina.problem import SpatialProblem
from pina.operators import laplacian
from pina.geometry import CartesianDomain
from pina import Condition, LabelTensor, PINN
from pina.trainer import Trainer
from pina.model import FeedForward
from pina.equation.equation import Equation
from pina.equation.equation_factory import FixedValue
def laplace_equation(input_, output_):
force_term = (torch.sin(input_.extract(['x'])*torch.pi) *
torch.sin(input_.extract(['y'])*torch.pi))
delta_u = laplacian(output_.extract(['u']), input_)
return delta_u - force_term
my_laplace = Equation(laplace_equation)
in_ = LabelTensor(torch.tensor([[0., 1.]]), ['x', 'y'])
out_ = LabelTensor(torch.tensor([[0.]]), ['u'])
class Poisson(SpatialProblem):
output_variables = ['u']
spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]})
conditions = {
'gamma1': Condition(
location=CartesianDomain({'x': [0, 1], 'y': 1}),
equation=FixedValue(0.0)),
'gamma2': Condition(
location=CartesianDomain({'x': [0, 1], 'y': 0}),
equation=FixedValue(0.0)),
'gamma3': Condition(
location=CartesianDomain({'x': 1, 'y': [0, 1]}),
equation=FixedValue(0.0)),
'gamma4': Condition(
location=CartesianDomain({'x': 0, 'y': [0, 1]}),
equation=FixedValue(0.0)),
'D': Condition(
input_points=LabelTensor(torch.rand(size=(100, 2)), ['x', 'y']),
equation=my_laplace),
'data': Condition(
input_points=in_,
output_points=out_)
}
# make the problem
poisson_problem = Poisson()
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
n = 10
poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
model = FeedForward(len(poisson_problem.input_variables),len(poisson_problem.output_variables))
# make the solver
solver = PINN(problem=poisson_problem, model=model)
def test_switch_optimizer_constructor():
SwitchOptimizer(new_optimizers=torch.optim.Adam,
new_optimizers_kwargs={'lr':0.01},
epoch_switch=10)
with pytest.raises(ValueError):
SwitchOptimizer(new_optimizers=[torch.optim.Adam, torch.optim.Adam],
new_optimizers_kwargs=[{'lr':0.01}],
epoch_switch=10)
def test_switch_optimizer_routine():
# make the trainer
trainer = Trainer(solver=solver, callbacks=[SwitchOptimizer(new_optimizers=torch.optim.LBFGS,
new_optimizers_kwargs={'lr':0.01},
epoch_switch=3)], max_epochs=5)
trainer.train()