Rename nabla -> laplacian

This commit is contained in:
Pasquale Africa
2023-07-20 08:15:13 +00:00
committed by Nicola Demo
parent 92e0e4920b
commit 6c627c70e3
16 changed files with 62 additions and 63 deletions

View File

@@ -3,7 +3,7 @@ import torch
from pina import Span, Condition
from pina.problem import SpatialProblem, ParametricProblem
from pina.operators import grad, nabla
from pina.operators import grad, laplacian
# ===================================================== #
# #
@@ -42,11 +42,11 @@ class ParametricEllipticOptimalControl(SpatialProblem, ParametricProblem):
# equation terms as in https://arxiv.org/pdf/2110.13530.pdf
def term1(input_, output_):
laplace_p = nabla(output_, input_, components=['p'], d=['x1', 'x2'])
laplace_p = laplacian(output_, input_, components=['p'], d=['x1', 'x2'])
return output_.extract(['y']) - input_.extract(['mu']) - laplace_p
def term2(input_, output_):
laplace_y = nabla(output_, input_, components=['y'], d=['x1', 'x2'])
laplace_y = laplacian(output_, input_, components=['y'], d=['x1', 'x2'])
return - laplace_y - output_.extract(['u_param'])
def state_dirichlet(input_, output_):

View File

@@ -1,7 +1,7 @@
import torch
from pina.problem import SpatialProblem, ParametricProblem
from pina.operators import nabla
from pina.operators import laplacian
from pina import Span, Condition
# ===================================================== #
@@ -28,7 +28,7 @@ class ParametricPoisson(SpatialProblem, ParametricProblem):
force_term = torch.exp(
- 2*(input_.extract(['x']) - input_.extract(['mu1']))**2
- 2*(input_.extract(['y']) - input_.extract(['mu2']))**2)
return nabla(output_.extract(['u']), input_) - force_term
return laplacian(output_.extract(['u']), input_) - force_term
# define nill dirichlet boundary conditions
def nil_dirichlet(input_, output_):

View File

@@ -2,7 +2,7 @@ import numpy as np
import torch
from pina.problem import SpatialProblem
from pina.operators import nabla
from pina.operators import laplacian
from pina import Condition, Span
# ===================================================== #
@@ -26,8 +26,8 @@ class Poisson(SpatialProblem):
def laplace_equation(input_, output_):
force_term = (torch.sin(input_.extract(['x'])*torch.pi) *
torch.sin(input_.extract(['y'])*torch.pi))
nabla_u = nabla(output_.extract(['u']), input_)
return nabla_u - force_term
delta_u = laplacian(output_.extract(['u']), input_)
return delta_u - force_term
# define nill dirichlet boundary conditions
def nil_dirichlet(input_, output_):

View File

@@ -2,7 +2,7 @@ import numpy as np
import torch
from pina.problem import SpatialProblem
from pina.operators import nabla, grad, div
from pina.operators import laplacian, grad, div
from pina import Condition, Span, LabelTensor
# ===================================================== #
@@ -25,11 +25,10 @@ class Stokes(SpatialProblem):
# define the momentum equation
def momentum(input_, output_):
nabla_ = torch.hstack((LabelTensor(nabla(output_.extract(['ux']), input_), ['x']),
LabelTensor(nabla(output_.extract(['uy']), input_), ['y'])))
return - nabla_ + grad(output_.extract(['p']), input_)
delta_ = torch.hstack((LabelTensor(laplacian(output_.extract(['ux']), input_), ['x']),
LabelTensor(laplacian(output_.extract(['uy']), input_), ['y'])))
return - delta_ + grad(output_.extract(['p']), input_)
# define the continuity equation
def continuity(input_, output_):
return div(output_.extract(['ux', 'uy']), input_)

View File

@@ -1,6 +1,6 @@
""" Module """
from .equation import Equation
from ..operators import grad, div, nabla
from ..operators import grad, div, laplacian
class FixedValue(Equation):
@@ -92,5 +92,5 @@ class Laplace(Equation):
are considered. Default is ``None``.
"""
def equation(input_, output_):
return nabla(output_, input_, components=components, d=d)
return laplacian(output_, input_, components=components, d=d)
super().__init__(equation)

View File

@@ -145,26 +145,26 @@ def div(output_, input_, components=None, d=None):
return div
def nabla(output_, input_, components=None, d=None, method='std'):
def laplacian(output_, input_, components=None, d=None, method='std'):
"""
Perform nabla (laplace) operator. The operator works for vectorial and
Compute Laplace operator. The operator works for vectorial and
scalar functions, with multiple input coordinates.
:param LabelTensor output_: the output tensor onto which computing the
nabla.
Laplacian.
:param LabelTensor input_: the input tensor with respect to which computing
the nabla.
the Laplacian.
:param list(str) components: the name of the output variables to calculate
the nabla for. It should be a subset of the output labels. If None,
the Laplacian for. It should be a subset of the output labels. If None,
all the output variables are considered. Default is None.
:param list(str) d: the name of the input variables on which the nabla
:param list(str) d: the name of the input variables on which the Laplacian
is calculated. d should be a subset of the input labels. If None, all
the input variables are considered. Default is None.
:param str method: used method to calculate nabla, defaults to 'std'.
:param str method: used method to calculate Laplacian, defaults to 'std'.
:raises ValueError: for vectorial field derivative with respect to
all coordinates must be performed.
:raises NotImplementedError: 'divgrad' not implemented as method.
:return: The tensor containing the result of the nabla operator.
:return: The tensor containing the result of the Laplacian operator.
:rtype: LabelTensor
"""
if d is None:
@@ -217,15 +217,15 @@ def advection(output_, input_, velocity_field, components=None, d=None):
with multiple input coordinates.
:param LabelTensor output_: the output tensor onto which computing the
nabla.
advection.
:param LabelTensor input_: the input tensor with respect to which computing
the nabla.
the advection.
:param str velocity_field: the name of the output variables which is used
as velocity field. It should be a subset of the output labels.
:param list(str) components: the name of the output variables to calculate
the nabla for. It should be a subset of the output labels. If None,
the Laplacian for. It should be a subset of the output labels. If None,
all the output variables are considered. Default is None.
:param list(str) d: the name of the input variables on which the nabla
:param list(str) d: the name of the input variables on which the advection
is calculated. d should be a subset of the input labels. If None, all
the input variables are considered. Default is None.
:return: the tensor containing the result of the advection operator.

View File

@@ -13,7 +13,7 @@ class TimeDependentProblem(AbstractProblem):
:Example:
>>> from pina.problem import SpatialProblem, TimeDependentProblem
>>> from pina.operators import grad, nabla
>>> from pina.operators import grad, laplacian
>>> from pina import Condition, Span
>>> import torch
>>>
@@ -26,8 +26,8 @@ class TimeDependentProblem(AbstractProblem):
>>> def wave_equation(input_, output_):
>>> u_t = grad(output_, input_, components=['u'], d=['t'])
>>> u_tt = grad(u_t, input_, components=['dudt'], d=['t'])
>>> nabla_u = nabla(output_, input_, components=['u'], d=['x'])
>>> return nabla_u - u_tt
>>> delta_u = laplacian(output_, input_, components=['u'], d=['x'])
>>> return delta_u - u_tt
>>>
>>> def nil_dirichlet(input_, output_):
>>> value = 0.0

View File

@@ -4,7 +4,7 @@ import pytest
from pina import LabelTensor, Condition, CartesianDomain, PINN
from pina.problem import SpatialProblem
from pina.model import FeedForward
from pina.operators import nabla
from pina.operators import laplacian
from pina.equation.equation_factory import FixedValue

View File

@@ -1,5 +1,5 @@
from pina.equation import Equation
from pina.operators import grad, nabla
from pina.operators import grad, laplacian
from pina import LabelTensor
import torch
import pytest
@@ -13,8 +13,8 @@ def eq1(input_, output_):
def eq2(input_, output_):
force_term = (torch.sin(input_.extract(['x'])*torch.pi) *
torch.sin(input_.extract(['y'])*torch.pi))
nabla_u = nabla(output_.extract(['u1']), input_)
return nabla_u - force_term
delta_u = laplacian(output_.extract(['u1']), input_)
return delta_u - force_term
def foo():
pass

View File

@@ -1,5 +1,5 @@
from pina.equation import SystemEquation
from pina.operators import grad, nabla
from pina.operators import grad, laplacian
from pina import LabelTensor
import torch
import pytest
@@ -13,8 +13,8 @@ def eq1(input_, output_):
def eq2(input_, output_):
force_term = (torch.sin(input_.extract(['x'])*torch.pi) *
torch.sin(input_.extract(['y'])*torch.pi))
nabla_u = nabla(output_.extract(['u1']), input_)
return nabla_u - force_term
delta_u = laplacian(output_.extract(['u1']), input_)
return delta_u - force_term
def foo():
pass

View File

@@ -4,7 +4,7 @@ import pytest
from pina import LabelTensor, Condition, CartesianDomain, PINN
from pina.problem import SpatialProblem
from pina.model import FeedForward
from pina.operators import nabla
from pina.operators import laplacian

View File

@@ -2,7 +2,7 @@ import torch
import pytest
from pina import LabelTensor
from pina.operators import grad, div, nabla
from pina.operators import grad, div, laplacian
def func_vec(x):
return x**2
@@ -41,13 +41,13 @@ def test_div_vector_output():
grad_tensor_v = div(tensor_v, inp, components=['a', 'b'], d=['x', 'mu'])
assert grad_tensor_v.shape == (inp.shape[0], 1)
def test_nabla_scalar_output():
laplace_tensor_v = nabla(tensor_s, inp, components=['a'], d=['x', 'y'])
def test_laplacian_scalar_output():
laplace_tensor_v = laplacian(tensor_s, inp, components=['a'], d=['x', 'y'])
assert laplace_tensor_v.shape == tensor_s.shape
def test_nabla_vector_output():
laplace_tensor_v = nabla(tensor_v, inp)
def test_laplacian_vector_output():
laplace_tensor_v = laplacian(tensor_v, inp)
assert laplace_tensor_v.shape == tensor_v.shape
laplace_tensor_v = nabla(tensor_v, inp, components=['a', 'b'], d=['x', 'y'])
laplace_tensor_v = laplacian(tensor_v, inp, components=['a', 'b'], d=['x', 'y'])
assert laplace_tensor_v.shape == tensor_v.extract(['a', 'b']).shape

View File

@@ -2,7 +2,7 @@ import torch
import pytest
from pina.problem import SpatialProblem
from pina.operators import nabla
from pina.operators import laplacian
from pina import LabelTensor, Condition
from pina.geometry import CartesianDomain
from pina.equation.equation import Equation
@@ -12,8 +12,8 @@ from pina.equation.equation_factory import FixedValue
def laplace_equation(input_, output_):
force_term = (torch.sin(input_.extract(['x'])*torch.pi) *
torch.sin(input_.extract(['y'])*torch.pi))
nabla_u = nabla(output_.extract(['u']), input_)
return nabla_u - force_term
delta_u = laplacian(output_.extract(['u']), input_)
return delta_u - force_term
my_laplace = Equation(laplace_equation)
in_ = LabelTensor(torch.tensor([[0., 1.]], requires_grad=True), ['x', 'y'])

View File

@@ -2,7 +2,7 @@ import torch
import pytest
from pina.problem import SpatialProblem
from pina.operators import nabla
from pina.operators import laplacian
from pina.geometry import CartesianDomain
from pina import Condition, LabelTensor, PINN
from pina.trainer import Trainer
@@ -16,8 +16,8 @@ from pina.loss import LpLoss
def laplace_equation(input_, output_):
force_term = (torch.sin(input_.extract(['x'])*torch.pi) *
torch.sin(input_.extract(['y'])*torch.pi))
nabla_u = nabla(output_.extract(['u']), input_)
return nabla_u - force_term
delta_u = laplacian(output_.extract(['u']), input_)
return delta_u - force_term
my_laplace = Equation(laplace_equation)
in_ = LabelTensor(torch.tensor([[0., 1.]]), ['x', 'y'])

View File

@@ -25,7 +25,7 @@ import torch
from torch.nn import Softplus
from pina.problem import SpatialProblem
from pina.operators import nabla
from pina.operators import laplacian
from pina.model import FeedForward
from pina import Condition, Span, PINN, LabelTensor, Plotter
@@ -43,8 +43,8 @@ class Poisson(SpatialProblem):
def laplace_equation(input_, output_):
force_term = (torch.sin(input_.extract(['x'])*torch.pi) *
torch.sin(input_.extract(['y'])*torch.pi))
nabla_u = nabla(output_, input_, components=['u'], d=['x', 'y'])
return nabla_u - force_term
delta_u = laplacian(output_, input_, components=['u'], d=['x', 'y'])
return delta_u - force_term
def nil_dirichlet(input_, output_):
value = 0.0

View File

@@ -27,7 +27,7 @@
import torch
from pina.problem import SpatialProblem, TimeDependentProblem
from pina.operators import nabla, grad
from pina.operators import laplacian, grad
from pina.model import Network
from pina import Condition, Span, PINN, Plotter
@@ -45,8 +45,8 @@ class Wave(TimeDependentProblem, SpatialProblem):
def wave_equation(input_, output_):
u_t = grad(output_, input_, components=['u'], d=['t'])
u_tt = grad(u_t, input_, components=['dudt'], d=['t'])
nabla_u = nabla(output_, input_, components=['u'], d=['x', 'y'])
return nabla_u - u_tt
delta_u = laplacian(output_, input_, components=['u'], d=['x', 'y'])
return delta_u - u_tt
def nil_dirichlet(input_, output_):
value = 0.0