Improve differential operators (#528)
* Improve grad logic and fix issues * Add operators' fast versions * Fix bug in laplacian + new tests + restructuring Co-authored-by: Dario Coscia <dariocos99@gmail.com> * fix advection bug --------- Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Dario Coscia
parent
ce0c033de1
commit
485c8dd789
564
pina/operator.py
564
pina/operator.py
@@ -10,10 +10,309 @@ Each differential operator takes the following inputs:
|
|||||||
- A tensor with respect to which the operator is computed.
|
- A tensor with respect to which the operator is computed.
|
||||||
- The names of the output variables for which the operator is evaluated.
|
- The names of the output variables for which the operator is evaluated.
|
||||||
- The names of the variables with respect to which the operator is computed.
|
- The names of the variables with respect to which the operator is computed.
|
||||||
|
|
||||||
|
Each differential operator has its fast version, which performs no internal
|
||||||
|
checks on input and output tensors. For these methods, the user is always
|
||||||
|
required to specify both ``components`` and ``d`` as lists of strings.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pina.label_tensor import LabelTensor
|
from .label_tensor import LabelTensor
|
||||||
|
|
||||||
|
|
||||||
|
def _check_values(output_, input_, components, d):
|
||||||
|
"""
|
||||||
|
Perform checks on arguments of differential operators.
|
||||||
|
|
||||||
|
:param LabelTensor output_: The output tensor on which the operator is
|
||||||
|
computed.
|
||||||
|
:param LabelTensor input_: The input tensor with respect to which the
|
||||||
|
operator is computed.
|
||||||
|
:param components: The names of the output variables for which to compute
|
||||||
|
the operator. It must be a subset of the output labels.
|
||||||
|
If ``None``, all output variables are considered. Default is ``None``.
|
||||||
|
:type components: str | list[str]
|
||||||
|
:param d: The names of the input variables with respect to which the
|
||||||
|
operator is computed. It must be a subset of the input labels.
|
||||||
|
If ``None``, all input variables are considered. Default is ``None``.
|
||||||
|
:type d: str | list[str]
|
||||||
|
:raises TypeError: If the input tensor is not a LabelTensor.
|
||||||
|
:raises TypeError: If the output tensor is not a LabelTensor.
|
||||||
|
:raises RuntimeError: If derivative labels are missing from the ``input_``.
|
||||||
|
:raises RuntimeError: If component labels are missing from the ``output_``.
|
||||||
|
:return: The components and d lists.
|
||||||
|
:rtype: tuple[list[str], list[str]]
|
||||||
|
"""
|
||||||
|
# Check if the input is a LabelTensor
|
||||||
|
if not isinstance(input_, LabelTensor):
|
||||||
|
raise TypeError("Input must be a LabelTensor.")
|
||||||
|
|
||||||
|
# Check if the output is a LabelTensor
|
||||||
|
if not isinstance(output_, LabelTensor):
|
||||||
|
raise TypeError("Output must be a LabelTensor.")
|
||||||
|
|
||||||
|
# If no labels are provided, use all labels
|
||||||
|
d = d or input_.labels
|
||||||
|
components = components or output_.labels
|
||||||
|
|
||||||
|
# Convert to list if not already
|
||||||
|
d = d if isinstance(d, list) else [d]
|
||||||
|
components = components if isinstance(components, list) else [components]
|
||||||
|
|
||||||
|
# Check if all labels are present in the input tensor
|
||||||
|
if not all(di in input_.labels for di in d):
|
||||||
|
raise RuntimeError("Derivative labels missing from input tensor.")
|
||||||
|
|
||||||
|
# Check if all labels are present in the output tensor
|
||||||
|
if not all(c in output_.labels for c in components):
|
||||||
|
raise RuntimeError("Component label missing from output tensor.")
|
||||||
|
|
||||||
|
return components, d
|
||||||
|
|
||||||
|
|
||||||
|
def _scalar_grad(output_, input_, d):
|
||||||
|
"""
|
||||||
|
Compute the gradient of a scalar-valued ``output_``.
|
||||||
|
|
||||||
|
:param LabelTensor output_: The output tensor on which the gradient is
|
||||||
|
computed. It must be a column tensor.
|
||||||
|
:param LabelTensor input_: The input tensor with respect to which the
|
||||||
|
gradient is computed.
|
||||||
|
:param list[str] d: The names of the input variables with respect to
|
||||||
|
which the gradient is computed. It must be a subset of the input
|
||||||
|
labels. If ``None``, all input variables are considered.
|
||||||
|
:return: The computed gradient tensor.
|
||||||
|
:rtype: LabelTensor
|
||||||
|
"""
|
||||||
|
grad_out = torch.autograd.grad(
|
||||||
|
outputs=output_,
|
||||||
|
inputs=input_,
|
||||||
|
grad_outputs=torch.ones_like(output_),
|
||||||
|
create_graph=True,
|
||||||
|
retain_graph=True,
|
||||||
|
allow_unused=True,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
return grad_out[..., [input_.labels.index(i) for i in d]]
|
||||||
|
|
||||||
|
|
||||||
|
def _scalar_laplacian(output_, input_, d):
|
||||||
|
"""
|
||||||
|
Compute the laplacian of a scalar-valued ``output_``.
|
||||||
|
|
||||||
|
:param LabelTensor output_: The output tensor on which the laplacian is
|
||||||
|
computed. It must be a column tensor.
|
||||||
|
:param LabelTensor input_: The input tensor with respect to which the
|
||||||
|
laplacian is computed.
|
||||||
|
:param list[str] d: The names of the input variables with respect to
|
||||||
|
which the laplacian is computed. It must be a subset of the input
|
||||||
|
labels. If ``None``, all input variables are considered.
|
||||||
|
:return: The computed laplacian tensor.
|
||||||
|
:rtype: LabelTensor
|
||||||
|
"""
|
||||||
|
first_grad = fast_grad(
|
||||||
|
output_=output_, input_=input_, components=output_.labels, d=d
|
||||||
|
)
|
||||||
|
second_grad = fast_grad(
|
||||||
|
output_=first_grad, input_=input_, components=first_grad.labels, d=d
|
||||||
|
)
|
||||||
|
labels_to_extract = [f"d{c}d{d_}" for c, d_ in zip(first_grad.labels, d)]
|
||||||
|
return torch.sum(
|
||||||
|
second_grad.extract(labels_to_extract), dim=-1, keepdim=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def fast_grad(output_, input_, components, d):
|
||||||
|
"""
|
||||||
|
Compute the gradient of the ``output_`` with respect to the ``input``.
|
||||||
|
|
||||||
|
Unlike ``grad``, this function performs no internal checks on input and
|
||||||
|
output tensors. The user is required to specify both ``components`` and
|
||||||
|
``d`` as lists of strings. It is designed to enhance computation speed.
|
||||||
|
|
||||||
|
This operator supports both vector-valued and scalar-valued functions with
|
||||||
|
one or multiple input coordinates.
|
||||||
|
|
||||||
|
:param LabelTensor output_: The output tensor on which the gradient is
|
||||||
|
computed.
|
||||||
|
:param LabelTensor input_: The input tensor with respect to which the
|
||||||
|
gradient is computed.
|
||||||
|
:param list[str] components: The names of the output variables for which to
|
||||||
|
compute the gradient. It must be a subset of the output labels.
|
||||||
|
:param list[str] d: The names of the input variables with respect to which
|
||||||
|
the gradient is computed. It must be a subset of the input labels.
|
||||||
|
:return: The computed gradient tensor.
|
||||||
|
:rtype: LabelTensor
|
||||||
|
"""
|
||||||
|
# Scalar gradient
|
||||||
|
if output_.shape[-1] == 1:
|
||||||
|
return LabelTensor(
|
||||||
|
_scalar_grad(output_=output_, input_=input_, d=d),
|
||||||
|
labels=[f"d{output_.labels[0]}d{i}" for i in d],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Vector gradient
|
||||||
|
grads = torch.cat(
|
||||||
|
[
|
||||||
|
_scalar_grad(output_=output_.extract(c), input_=input_, d=d)
|
||||||
|
for c in components
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
return LabelTensor(
|
||||||
|
grads, labels=[f"d{c}d{i}" for c in components for i in d]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def fast_div(output_, input_, components, d):
|
||||||
|
"""
|
||||||
|
Compute the divergence of the ``output_`` with respect to ``input``.
|
||||||
|
|
||||||
|
Unlike ``div``, this function performs no internal checks on input and
|
||||||
|
output tensors. The user is required to specify both ``components`` and
|
||||||
|
``d`` as lists of strings. It is designed to enhance computation speed.
|
||||||
|
|
||||||
|
This operator supports vector-valued functions with multiple input
|
||||||
|
coordinates.
|
||||||
|
|
||||||
|
:param LabelTensor output_: The output tensor on which the divergence is
|
||||||
|
computed.
|
||||||
|
:param LabelTensor input_: The input tensor with respect to which the
|
||||||
|
divergence is computed.
|
||||||
|
:param list[str] components: The names of the output variables for which to
|
||||||
|
compute the divergence. It must be a subset of the output labels.
|
||||||
|
:param list[str] d: The names of the input variables with respect to which
|
||||||
|
the divergence is computed. It must be a subset of the input labels.
|
||||||
|
:rtype: LabelTensor
|
||||||
|
"""
|
||||||
|
grad_out = fast_grad(
|
||||||
|
output_=output_, input_=input_, components=components, d=d
|
||||||
|
)
|
||||||
|
tensors_to_sum = [
|
||||||
|
grad_out.extract(f"d{c}d{d_}") for c, d_ in zip(components, d)
|
||||||
|
]
|
||||||
|
|
||||||
|
return LabelTensor.summation(tensors_to_sum)
|
||||||
|
|
||||||
|
|
||||||
|
def fast_laplacian(output_, input_, components, d, method="std"):
|
||||||
|
"""
|
||||||
|
Compute the laplacian of the ``output_`` with respect to ``input``.
|
||||||
|
|
||||||
|
Unlike ``laplacian``, this function performs no internal checks on input and
|
||||||
|
output tensors. The user is required to specify both ``components`` and
|
||||||
|
``d`` as lists of strings. It is designed to enhance computation speed.
|
||||||
|
|
||||||
|
This operator supports both vector-valued and scalar-valued functions with
|
||||||
|
one or multiple input coordinates.
|
||||||
|
|
||||||
|
:param LabelTensor output_: The output tensor on which the laplacian is
|
||||||
|
computed.
|
||||||
|
:param LabelTensor input_: The input tensor with respect to which the
|
||||||
|
laplacian is computed.
|
||||||
|
:param list[str] components: The names of the output variables for which to
|
||||||
|
compute the laplacian. It must be a subset of the output labels.
|
||||||
|
:param list[str] d: The names of the input variables with respect to which
|
||||||
|
the laplacian is computed. It must be a subset of the input labels.
|
||||||
|
:param str method: The method used to compute the Laplacian. Available
|
||||||
|
methods are ``std`` and ``divgrad``. The ``std`` method computes the
|
||||||
|
trace of the Hessian matrix, while the ``divgrad`` method computes the
|
||||||
|
divergence of the gradient. Default is ``std``.
|
||||||
|
:return: The computed laplacian tensor.
|
||||||
|
:rtype: LabelTensor
|
||||||
|
"""
|
||||||
|
# Scalar laplacian
|
||||||
|
if output_.shape[-1] == 1:
|
||||||
|
return LabelTensor(
|
||||||
|
_scalar_laplacian(output_=output_, input_=input_, d=d),
|
||||||
|
labels=[f"dd{c}" for c in components],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize the result tensor and its labels
|
||||||
|
labels = [f"dd{c}" for c in components]
|
||||||
|
result = torch.empty(
|
||||||
|
input_.shape[0], len(components), device=output_.device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Vector laplacian
|
||||||
|
if method == "std":
|
||||||
|
result = torch.cat(
|
||||||
|
[
|
||||||
|
_scalar_laplacian(
|
||||||
|
output_=output_.extract(c), input_=input_, d=d
|
||||||
|
)
|
||||||
|
for c in components
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif method == "divgrad":
|
||||||
|
grads = fast_grad(
|
||||||
|
output_=output_, input_=input_, components=components, d=d
|
||||||
|
)
|
||||||
|
result = torch.cat(
|
||||||
|
[
|
||||||
|
fast_div(
|
||||||
|
output_=grads,
|
||||||
|
input_=input_,
|
||||||
|
components=[f"d{c}d{i}" for i in d],
|
||||||
|
d=d,
|
||||||
|
)
|
||||||
|
for c in components
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid method. Available methods are ``std`` and ``divgrad``."
|
||||||
|
)
|
||||||
|
|
||||||
|
return LabelTensor(result, labels=labels)
|
||||||
|
|
||||||
|
|
||||||
|
def fast_advection(output_, input_, velocity_field, components, d):
|
||||||
|
"""
|
||||||
|
Perform the advection operation on the ``output_`` with respect to the
|
||||||
|
``input``. This operator support vector-valued functions with multiple input
|
||||||
|
coordinates.
|
||||||
|
|
||||||
|
Unlike ``advection``, this function performs no internal checks on input and
|
||||||
|
output tensors. The user is required to specify both ``components`` and
|
||||||
|
``d`` as lists of strings. It is designed to enhance computation speed.
|
||||||
|
|
||||||
|
:param LabelTensor output_: The output tensor on which the advection is
|
||||||
|
computed.
|
||||||
|
:param LabelTensor input_: the input tensor with respect to which advection
|
||||||
|
is computed.
|
||||||
|
:param str velocity_field: The name of the output variable used as velocity
|
||||||
|
field. It must be chosen among the output labels.
|
||||||
|
:param list[str] components: The names of the output variables for which to
|
||||||
|
compute the advection. It must be a subset of the output labels.
|
||||||
|
:param list[str] d: The names of the input variables with respect to which
|
||||||
|
the advection is computed. It must be a subset of the input labels.
|
||||||
|
:return: The computed advection tensor.
|
||||||
|
:rtype: torch.Tensor
|
||||||
|
"""
|
||||||
|
# Add a dimension to the velocity field for following operations
|
||||||
|
velocity = output_.extract(velocity_field).unsqueeze(-1)
|
||||||
|
|
||||||
|
# Remove the velocity field from the components
|
||||||
|
filter_components = [c for c in components if c != velocity_field]
|
||||||
|
|
||||||
|
# Compute the gradient
|
||||||
|
grads = fast_grad(
|
||||||
|
output_=output_, input_=input_, components=filter_components, d=d
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reshape into [..., len(filter_components), len(d)]
|
||||||
|
tmp = grads.reshape(*output_.shape[:-1], len(filter_components), len(d))
|
||||||
|
|
||||||
|
# Transpose to [..., len(d), len(filter_components)]
|
||||||
|
tmp = tmp.transpose(-1, -2)
|
||||||
|
|
||||||
|
return (tmp * velocity).sum(dim=tmp.tensor.ndim - 2)
|
||||||
|
|
||||||
|
|
||||||
def grad(output_, input_, components=None, d=None):
|
def grad(output_, input_, components=None, d=None):
|
||||||
@@ -27,95 +326,25 @@ def grad(output_, input_, components=None, d=None):
|
|||||||
computed.
|
computed.
|
||||||
:param LabelTensor input_: The input tensor with respect to which the
|
:param LabelTensor input_: The input tensor with respect to which the
|
||||||
gradient is computed.
|
gradient is computed.
|
||||||
:param components: The names of the output variables for which to
|
:param components: The names of the output variables for which to compute
|
||||||
compute the gradient. It must be a subset of the output labels.
|
the gradient. It must be a subset of the output labels.
|
||||||
If ``None``, all output variables are considered. Default is ``None``.
|
If ``None``, all output variables are considered. Default is ``None``.
|
||||||
:type components: str | list[str]
|
:type components: str | list[str]
|
||||||
:param d: The names of the input variables with respect to which
|
:param d: The names of the input variables with respect to which the
|
||||||
the gradient is computed. It must be a subset of the input labels.
|
gradient is computed. It must be a subset of the input labels.
|
||||||
If ``None``, all input variables are considered. Default is ``None``.
|
If ``None``, all input variables are considered. Default is ``None``.
|
||||||
:type d: str | list[str]
|
:type d: str | list[str]
|
||||||
:raises TypeError: If the input tensor is not a LabelTensor.
|
:raises TypeError: If the input tensor is not a LabelTensor.
|
||||||
:raises RuntimeError: If the output is a scalar field and the components
|
:raises TypeError: If the output tensor is not a LabelTensor.
|
||||||
are not equal to the output labels.
|
:raises RuntimeError: If derivative labels are missing from the ``input_``.
|
||||||
:raises NotImplementedError: If the output is neither a vector field nor a
|
:raises RuntimeError: If component labels are missing from the ``output_``.
|
||||||
scalar field.
|
|
||||||
:return: The computed gradient tensor.
|
:return: The computed gradient tensor.
|
||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
"""
|
"""
|
||||||
|
components, d = _check_values(
|
||||||
def grad_scalar_output(output_, input_, d):
|
output_=output_, input_=input_, components=components, d=d
|
||||||
"""
|
)
|
||||||
Compute the gradient of a scalar-valued ``output_``.
|
return fast_grad(output_=output_, input_=input_, components=components, d=d)
|
||||||
|
|
||||||
:param LabelTensor output_: The output tensor on which the gradient is
|
|
||||||
computed. It must be a column tensor.
|
|
||||||
:param LabelTensor input_: The input tensor with respect to which the
|
|
||||||
gradient is computed.
|
|
||||||
:param d: The names of the input variables with respect to
|
|
||||||
which the gradient is computed. It must be a subset of the input
|
|
||||||
labels. If ``None``, all input variables are considered.
|
|
||||||
:type d: str | list[str]
|
|
||||||
:raises RuntimeError: If a vectorial function is passed.
|
|
||||||
:raises RuntimeError: If missing derivative labels.
|
|
||||||
:return: The computed gradient tensor.
|
|
||||||
:rtype: LabelTensor
|
|
||||||
"""
|
|
||||||
|
|
||||||
if len(output_.labels) != 1:
|
|
||||||
raise RuntimeError("only scalar function can be differentiated")
|
|
||||||
if not all(di in input_.labels for di in d):
|
|
||||||
raise RuntimeError("derivative labels missing from input tensor")
|
|
||||||
|
|
||||||
output_fieldname = output_.labels[0]
|
|
||||||
gradients = torch.autograd.grad(
|
|
||||||
output_,
|
|
||||||
input_,
|
|
||||||
grad_outputs=torch.ones(
|
|
||||||
output_.size(), dtype=output_.dtype, device=output_.device
|
|
||||||
),
|
|
||||||
create_graph=True,
|
|
||||||
retain_graph=True,
|
|
||||||
allow_unused=True,
|
|
||||||
)[0]
|
|
||||||
gradients.labels = input_.stored_labels
|
|
||||||
gradients = gradients[..., [input_.labels.index(i) for i in d]]
|
|
||||||
gradients.labels = [f"d{output_fieldname}d{i}" for i in d]
|
|
||||||
return gradients
|
|
||||||
|
|
||||||
if not isinstance(input_, LabelTensor):
|
|
||||||
raise TypeError
|
|
||||||
|
|
||||||
if d is None:
|
|
||||||
d = input_.labels
|
|
||||||
|
|
||||||
if components is None:
|
|
||||||
components = output_.labels
|
|
||||||
|
|
||||||
if not isinstance(components, list):
|
|
||||||
components = [components]
|
|
||||||
|
|
||||||
if not isinstance(d, list):
|
|
||||||
d = [d]
|
|
||||||
|
|
||||||
if output_.shape[1] == 1: # scalar output ################################
|
|
||||||
|
|
||||||
if components != output_.labels:
|
|
||||||
raise RuntimeError
|
|
||||||
gradients = grad_scalar_output(output_, input_, d)
|
|
||||||
|
|
||||||
elif (
|
|
||||||
output_.shape[output_.ndim - 1] >= 2
|
|
||||||
): # vector output ##############################
|
|
||||||
tensor_to_cat = []
|
|
||||||
for i, c in enumerate(components):
|
|
||||||
c_output = output_.extract([c])
|
|
||||||
tensor_to_cat.append(grad_scalar_output(c_output, input_, d))
|
|
||||||
gradients = LabelTensor.cat(tensor_to_cat, dim=output_.tensor.ndim - 1)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
return gradients
|
|
||||||
|
|
||||||
|
|
||||||
def div(output_, input_, components=None, d=None):
|
def div(output_, input_, components=None, d=None):
|
||||||
@@ -129,51 +358,31 @@ def div(output_, input_, components=None, d=None):
|
|||||||
computed.
|
computed.
|
||||||
:param LabelTensor input_: The input tensor with respect to which the
|
:param LabelTensor input_: The input tensor with respect to which the
|
||||||
divergence is computed.
|
divergence is computed.
|
||||||
:param components: The names of the output variables for which to
|
:param components: The names of the output variables for which to compute
|
||||||
compute the divergence. It must be a subset of the output labels.
|
the divergence. It must be a subset of the output labels.
|
||||||
If ``None``, all output variables are considered. Default is ``None``.
|
If ``None``, all output variables are considered. Default is ``None``.
|
||||||
:type components: str | list[str]
|
:type components: str | list[str]
|
||||||
:param d: The names of the input variables with respect to which
|
:param d: The names of the input variables with respect to which the
|
||||||
the divergence is computed. It must be a subset of the input labels.
|
divergence is computed. It must be a subset of the input labels.
|
||||||
If ``None``, all input variables are considered. Default is ``None``.
|
If ``None``, all input variables are considered. Default is ``None``.
|
||||||
:type d: str | list[str]
|
:type components: str | list[str]
|
||||||
:raises TypeError: If the input tensor is not a LabelTensor.
|
:raises TypeError: If the input tensor is not a LabelTensor.
|
||||||
:raises ValueError: If the output is a scalar field.
|
:raises TypeError: If the output tensor is not a LabelTensor.
|
||||||
:raises ValueError: If the number of components is not equal to the number
|
:raises ValueError: If the length of ``components`` and ``d`` do not match.
|
||||||
of input variables.
|
|
||||||
:return: The computed divergence tensor.
|
:return: The computed divergence tensor.
|
||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
"""
|
"""
|
||||||
if not isinstance(input_, LabelTensor):
|
components, d = _check_values(
|
||||||
raise TypeError
|
output_=output_, input_=input_, components=components, d=d
|
||||||
|
)
|
||||||
if d is None:
|
|
||||||
d = input_.labels
|
|
||||||
|
|
||||||
if components is None:
|
|
||||||
components = output_.labels
|
|
||||||
|
|
||||||
if not isinstance(components, list):
|
|
||||||
components = [components]
|
|
||||||
|
|
||||||
if not isinstance(d, list):
|
|
||||||
d = [d]
|
|
||||||
|
|
||||||
if output_.shape[1] < 2 or len(components) < 2:
|
|
||||||
raise ValueError("div supported only for vector fields")
|
|
||||||
|
|
||||||
|
# Components and d must be of the same length
|
||||||
if len(components) != len(d):
|
if len(components) != len(d):
|
||||||
raise ValueError
|
raise ValueError(
|
||||||
|
"Divergence requires components and d to be of the same length."
|
||||||
|
)
|
||||||
|
|
||||||
grad_output = grad(output_, input_, components, d)
|
return fast_div(output_=output_, input_=input_, components=components, d=d)
|
||||||
labels = [None] * len(components)
|
|
||||||
tensors_to_sum = []
|
|
||||||
for i, (c, d_) in enumerate(zip(components, d)):
|
|
||||||
c_fields = f"d{c}d{d_}"
|
|
||||||
tensors_to_sum.append(grad_output.extract(c_fields))
|
|
||||||
labels[i] = c_fields
|
|
||||||
div_result = LabelTensor.summation(tensors_to_sum)
|
|
||||||
return div_result
|
|
||||||
|
|
||||||
|
|
||||||
def laplacian(output_, input_, components=None, d=None, method="std"):
|
def laplacian(output_, input_, components=None, d=None, method="std"):
|
||||||
@@ -195,71 +404,22 @@ def laplacian(output_, input_, components=None, d=None, method="std"):
|
|||||||
the laplacian is computed. It must be a subset of the input labels.
|
the laplacian is computed. It must be a subset of the input labels.
|
||||||
If ``None``, all input variables are considered. Default is ``None``.
|
If ``None``, all input variables are considered. Default is ``None``.
|
||||||
:type d: str | list[str]
|
:type d: str | list[str]
|
||||||
:param str method: The method used to compute the Laplacian. Default is
|
:param str method: The method used to compute the Laplacian. Available
|
||||||
``std``.
|
methods are ``std`` and ``divgrad``. The ``std`` method computes the
|
||||||
:raises NotImplementedError: If ``std=divgrad``.
|
trace of the Hessian matrix, while the ``divgrad`` method computes the
|
||||||
|
divergence of the gradient. Default is ``std``.
|
||||||
|
:raises TypeError: If the input tensor is not a LabelTensor.
|
||||||
|
:raises TypeError: If the output tensor is not a LabelTensor.
|
||||||
|
:raises ValueError: If the passed method is neither ``std`` nor ``divgrad``.
|
||||||
:return: The computed laplacian tensor.
|
:return: The computed laplacian tensor.
|
||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
"""
|
"""
|
||||||
|
components, d = _check_values(
|
||||||
def scalar_laplace(output_, input_, components, d):
|
output_=output_, input_=input_, components=components, d=d
|
||||||
"""
|
)
|
||||||
Compute the laplacian of a scalar-valued ``output_``.
|
return fast_laplacian(
|
||||||
|
output_=output_, input_=input_, components=components, d=d
|
||||||
:param LabelTensor output_: The output tensor on which the laplacian is
|
)
|
||||||
computed. It must be a column tensor.
|
|
||||||
:param LabelTensor input_: The input tensor with respect to which the
|
|
||||||
laplacian is computed.
|
|
||||||
:param components: The names of the output variables for which
|
|
||||||
to compute the laplacian. It must be a subset of the output labels.
|
|
||||||
If ``None``, all output variables are considered.
|
|
||||||
:type components: str | list[str]
|
|
||||||
:param d: The names of the input variables with respect to
|
|
||||||
which the laplacian is computed. It must be a subset of the input
|
|
||||||
labels. If ``None``, all input variables are considered.
|
|
||||||
:type d: str | list[str]
|
|
||||||
:return: The computed laplacian tensor.
|
|
||||||
:rtype: LabelTensor
|
|
||||||
"""
|
|
||||||
|
|
||||||
grad_output = grad(output_, input_, components=components, d=d)
|
|
||||||
result = torch.zeros(output_.shape[0], 1, device=output_.device)
|
|
||||||
|
|
||||||
for i, label in enumerate(grad_output.labels):
|
|
||||||
gg = grad(grad_output, input_, d=d, components=[label])
|
|
||||||
result[:, 0] += super(torch.Tensor, gg.T).__getitem__(i)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
if d is None:
|
|
||||||
d = input_.labels
|
|
||||||
|
|
||||||
if components is None:
|
|
||||||
components = output_.labels
|
|
||||||
|
|
||||||
if not isinstance(components, list):
|
|
||||||
components = [components]
|
|
||||||
|
|
||||||
if not isinstance(d, list):
|
|
||||||
d = [d]
|
|
||||||
|
|
||||||
if method == "divgrad":
|
|
||||||
raise NotImplementedError("divgrad not implemented as method")
|
|
||||||
|
|
||||||
if method == "std":
|
|
||||||
|
|
||||||
result = torch.empty(
|
|
||||||
input_.shape[0], len(components), device=output_.device
|
|
||||||
)
|
|
||||||
labels = [None] * len(components)
|
|
||||||
for idx, c in enumerate(components):
|
|
||||||
result[:, idx] = scalar_laplace(output_, input_, [c], d).flatten()
|
|
||||||
labels[idx] = f"dd{c}"
|
|
||||||
|
|
||||||
result = result.as_subclass(LabelTensor)
|
|
||||||
result.labels = labels
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def advection(output_, input_, velocity_field, components=None, d=None):
|
def advection(output_, input_, velocity_field, components=None, d=None):
|
||||||
@@ -274,34 +434,34 @@ def advection(output_, input_, velocity_field, components=None, d=None):
|
|||||||
is computed.
|
is computed.
|
||||||
:param str velocity_field: The name of the output variable used as velocity
|
:param str velocity_field: The name of the output variable used as velocity
|
||||||
field. It must be chosen among the output labels.
|
field. It must be chosen among the output labels.
|
||||||
:param components: The names of the output variables for which
|
:param components: The names of the output variables for which to compute
|
||||||
to compute the advection. It must be a subset of the output labels.
|
the advection. It must be a subset of the output labels.
|
||||||
If ``None``, all output variables are considered. Default is ``None``.
|
If ``None``, all output variables are considered. Default is ``None``.
|
||||||
:type components: str | list[str]
|
:type components: str | list[str]
|
||||||
:param d: The names of the input variables with respect to which
|
:param d: The names of the input variables with respect to which the
|
||||||
the advection is computed. It must be a subset of the input labels.
|
advection is computed. It must be a subset of the input labels.
|
||||||
If ``None``, all input variables are considered. Default is ``None``.
|
If ``None``, all input variables are considered. Default is ``None``.
|
||||||
:type d: str | list[str]
|
:type d: str | list[str]
|
||||||
|
:raises TypeError: If the input tensor is not a LabelTensor.
|
||||||
|
:raises TypeError: If the output tensor is not a LabelTensor.
|
||||||
|
:raises RuntimeError: If the velocity field is not in the output labels.
|
||||||
:return: The computed advection tensor.
|
:return: The computed advection tensor.
|
||||||
:rtype: LabelTensor
|
:rtype: torch.Tensor
|
||||||
"""
|
"""
|
||||||
if d is None:
|
components, d = _check_values(
|
||||||
d = input_.labels
|
output_=output_, input_=input_, components=components, d=d
|
||||||
|
|
||||||
if components is None:
|
|
||||||
components = output_.labels
|
|
||||||
|
|
||||||
if not isinstance(components, list):
|
|
||||||
components = [components]
|
|
||||||
|
|
||||||
if not isinstance(d, list):
|
|
||||||
d = [d]
|
|
||||||
|
|
||||||
tmp = (
|
|
||||||
grad(output_, input_, components, d)
|
|
||||||
.reshape(-1, len(components), len(d))
|
|
||||||
.transpose(0, 1)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
tmp *= output_.extract(velocity_field)
|
# Check if velocity field is present in the output labels
|
||||||
return tmp.sum(dim=2).T
|
if velocity_field not in output_.labels:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Velocity {velocity_field} is not present in the output labels."
|
||||||
|
)
|
||||||
|
|
||||||
|
return fast_advection(
|
||||||
|
output_=output_,
|
||||||
|
input_=input_,
|
||||||
|
velocity_field=velocity_field,
|
||||||
|
components=components,
|
||||||
|
d=d,
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,205 +1,317 @@
|
|||||||
import torch
|
import torch
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from pina import LabelTensor
|
from pina import LabelTensor
|
||||||
from pina.operator import grad, div, laplacian
|
from pina.operator import grad, div, laplacian, advection
|
||||||
|
|
||||||
|
|
||||||
def func_vector(x):
|
class Function(object):
|
||||||
return x**2
|
|
||||||
|
def __iter__(self):
|
||||||
|
functions = [
|
||||||
|
(
|
||||||
|
getattr(self, f"{name}_input"),
|
||||||
|
getattr(self, f"{name}"),
|
||||||
|
getattr(self, f"{name}_grad"),
|
||||||
|
getattr(self, f"{name}_div"),
|
||||||
|
getattr(self, f"{name}_lap"),
|
||||||
|
)
|
||||||
|
for name in [
|
||||||
|
"scalar_scalar",
|
||||||
|
"scalar_vector",
|
||||||
|
"vector_scalar",
|
||||||
|
"vector_vector",
|
||||||
|
]
|
||||||
|
]
|
||||||
|
return iter(functions)
|
||||||
|
|
||||||
|
# Scalar to scalar function
|
||||||
|
@staticmethod
|
||||||
|
def scalar_scalar(x):
|
||||||
|
return x**2
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def scalar_scalar_grad(x):
|
||||||
|
return 2 * x
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def scalar_scalar_div(x):
|
||||||
|
return 2 * x
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def scalar_scalar_lap(x):
|
||||||
|
return 2 * torch.ones_like(x)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def scalar_scalar_input():
|
||||||
|
input_ = torch.rand((20, 1), requires_grad=True)
|
||||||
|
return LabelTensor(input_, ["x"])
|
||||||
|
|
||||||
|
# Scalar to vector function
|
||||||
|
@staticmethod
|
||||||
|
def scalar_vector(x):
|
||||||
|
u = x**2
|
||||||
|
v = x**3 + x
|
||||||
|
return torch.cat((u, v), dim=-1)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def scalar_vector_grad(x):
|
||||||
|
u = 2 * x
|
||||||
|
v = 3 * x**2 + 1
|
||||||
|
return torch.cat((u, v), dim=-1)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def scalar_vector_div(x):
|
||||||
|
return ValueError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def scalar_vector_lap(x):
|
||||||
|
u = 2 * torch.ones_like(x)
|
||||||
|
v = 6 * x
|
||||||
|
return torch.cat((u, v), dim=-1)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def scalar_vector_input():
|
||||||
|
input_ = torch.rand((20, 1), requires_grad=True)
|
||||||
|
return LabelTensor(input_, ["x"])
|
||||||
|
|
||||||
|
# Vector to scalar function
|
||||||
|
@staticmethod
|
||||||
|
def vector_scalar(x):
|
||||||
|
return torch.prod(x**2, dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def vector_scalar_grad(x):
|
||||||
|
return 2 * torch.prod(x**2, dim=-1, keepdim=True) / x
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def vector_scalar_div(x):
|
||||||
|
return ValueError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def vector_scalar_lap(x):
|
||||||
|
return 2 * torch.sum(
|
||||||
|
torch.prod(x**2, dim=-1, keepdim=True) / x**2,
|
||||||
|
dim=-1,
|
||||||
|
keepdim=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def vector_scalar_input():
|
||||||
|
input_ = torch.rand((20, 2), requires_grad=True)
|
||||||
|
return LabelTensor(input_, ["x", "yy"])
|
||||||
|
|
||||||
|
# Vector to vector function
|
||||||
|
@staticmethod
|
||||||
|
def vector_vector(x):
|
||||||
|
u = torch.prod(x**2, dim=-1, keepdim=True)
|
||||||
|
v = torch.sum(x**2, dim=-1, keepdim=True)
|
||||||
|
return torch.cat((u, v), dim=-1)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def vector_vector_grad(x):
|
||||||
|
u = 2 * torch.prod(x**2, dim=-1, keepdim=True) / x
|
||||||
|
v = 2 * x
|
||||||
|
return torch.cat((u, v), dim=-1)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def vector_vector_div(x):
|
||||||
|
u = 2 * torch.prod(x**2, dim=-1, keepdim=True) / x[..., 0]
|
||||||
|
v = 2 * x[..., 1]
|
||||||
|
return u + v
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def vector_vector_lap(x):
|
||||||
|
u = torch.sum(
|
||||||
|
2 * torch.prod(x**2, dim=-1, keepdim=True) / x**2,
|
||||||
|
dim=-1,
|
||||||
|
keepdim=True,
|
||||||
|
)
|
||||||
|
v = 2 * x.shape[-1] * torch.ones_like(u)
|
||||||
|
return torch.cat((u, v), dim=-1)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def vector_vector_input():
|
||||||
|
input_ = torch.rand((20, 2), requires_grad=True)
|
||||||
|
return LabelTensor(input_, ["x", "yy"])
|
||||||
|
|
||||||
|
|
||||||
def func_scalar(x):
|
@pytest.mark.parametrize(
|
||||||
x_ = x.extract(["x"])
|
"f",
|
||||||
y_ = x.extract(["y"])
|
Function(),
|
||||||
z_ = x.extract(["z"])
|
ids=["scalar_scalar", "scalar_vector", "vector_scalar", "vector_vector"],
|
||||||
return x_**2 + y_**2 + z_**2
|
)
|
||||||
|
def test_gradient(f):
|
||||||
|
|
||||||
|
# Unpack the function
|
||||||
|
func_input, func, func_grad, _, _ = f
|
||||||
|
|
||||||
data = torch.rand((20, 3))
|
# Define input and output
|
||||||
inp = LabelTensor(data, ["x", "y", "z"]).requires_grad_(True)
|
input_ = func_input()
|
||||||
labels = ["a", "b", "c"]
|
output_ = func(input_)
|
||||||
tensor_v = LabelTensor(func_vector(inp), labels)
|
labels = [f"u{i}" for i in range(output_.shape[-1])]
|
||||||
tensor_s = LabelTensor(func_scalar(inp).reshape(-1, 1), labels[0])
|
output_ = LabelTensor(output_, labels)
|
||||||
|
|
||||||
|
# Compute the true gradient and the pina gradient
|
||||||
|
pina_grad = grad(output_=output_, input_=input_)
|
||||||
|
true_grad = func_grad(input_)
|
||||||
|
|
||||||
def test_grad_scalar_output():
|
# Check the shape and labels of the gradient
|
||||||
grad_tensor_s = grad(tensor_s, inp)
|
n_components = len(output_.labels) * len(input_.labels)
|
||||||
true_val = 2 * inp
|
assert pina_grad.shape == (*output_.shape[:-1], n_components)
|
||||||
true_val.labels = inp.labels
|
assert pina_grad.labels == [
|
||||||
assert grad_tensor_s.shape == inp.shape
|
f"d{c}d{i}" for c in output_.labels for i in input_.labels
|
||||||
assert grad_tensor_s.labels == [
|
|
||||||
f"d{tensor_s.labels[0]}d{i}" for i in inp.labels
|
|
||||||
]
|
]
|
||||||
assert torch.allclose(grad_tensor_s, true_val)
|
|
||||||
|
|
||||||
grad_tensor_s = grad(tensor_s, inp, d=["x", "y"])
|
# Compare the values
|
||||||
assert grad_tensor_s.shape == (20, 2)
|
assert torch.allclose(pina_grad, true_grad)
|
||||||
assert grad_tensor_s.labels == [
|
|
||||||
f"d{tensor_s.labels[0]}d{i}" for i in ["x", "y"]
|
# Test if labels are handled correctly
|
||||||
]
|
grad(output_=output_, input_=input_, components=output_.labels[0])
|
||||||
assert torch.allclose(grad_tensor_s, true_val.extract(["x", "y"]))
|
grad(output_=output_, input_=input_, d=input_.labels[0])
|
||||||
|
|
||||||
|
# Should fail if input not a LabelTensor
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
grad(output_=output_, input_=input_.tensor)
|
||||||
|
|
||||||
|
# Should fail if output not a LabelTensor
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
grad(output_=output_.tensor, input_=input_)
|
||||||
|
|
||||||
|
# Should fail for non-existent input labels
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
grad(output_=output_, input_=input_, d=["x", "y"])
|
||||||
|
|
||||||
|
# Should fail for non-existent output labels
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
grad(output_=output_, input_=input_, components=["a", "b", "c"])
|
||||||
|
|
||||||
|
|
||||||
def test_grad_vector_output():
|
@pytest.mark.parametrize(
|
||||||
grad_tensor_v = grad(tensor_v, inp)
|
"f",
|
||||||
true_val = torch.cat(
|
Function(),
|
||||||
(
|
ids=["scalar_scalar", "scalar_vector", "vector_scalar", "vector_vector"],
|
||||||
2 * inp.extract(["x"]),
|
)
|
||||||
torch.zeros_like(inp.extract(["y"])),
|
def test_divergence(f):
|
||||||
torch.zeros_like(inp.extract(["z"])),
|
|
||||||
torch.zeros_like(inp.extract(["x"])),
|
# Unpack the function
|
||||||
2 * inp.extract(["y"]),
|
func_input, func, _, func_div, _ = f
|
||||||
torch.zeros_like(inp.extract(["z"])),
|
|
||||||
torch.zeros_like(inp.extract(["x"])),
|
# Define input and output
|
||||||
torch.zeros_like(inp.extract(["y"])),
|
input_ = func_input()
|
||||||
2 * inp.extract(["z"]),
|
output_ = func(input_)
|
||||||
),
|
labels = [f"u{i}" for i in range(output_.shape[-1])]
|
||||||
dim=1,
|
output_ = LabelTensor(output_, labels)
|
||||||
|
|
||||||
|
# Scalar to vector or vector to scalar functions
|
||||||
|
if func_div(input_) == ValueError:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
div(output_=output_, input_=input_)
|
||||||
|
|
||||||
|
# Scalar to scalar or vector to vector functions
|
||||||
|
else:
|
||||||
|
# Compute the true divergence and the pina divergence
|
||||||
|
pina_div = div(output_=output_, input_=input_)
|
||||||
|
true_div = func_div(input_)
|
||||||
|
|
||||||
|
# Check the shape and labels of the divergence
|
||||||
|
assert pina_div.shape == (*output_.shape[:-1], 1)
|
||||||
|
tmp_labels = [
|
||||||
|
f"d{c}d{d_}" for c, d_ in zip(output_.labels, input_.labels)
|
||||||
|
]
|
||||||
|
assert pina_div.labels == ["+".join(tmp_labels)]
|
||||||
|
|
||||||
|
# Compare the values
|
||||||
|
assert torch.allclose(pina_div, true_div)
|
||||||
|
|
||||||
|
# Test if labels are handled correctly. Performed in a single call to
|
||||||
|
# avoid components and d having different lengths.
|
||||||
|
div(
|
||||||
|
output_=output_,
|
||||||
|
input_=input_,
|
||||||
|
components=output_.labels[0],
|
||||||
|
d=input_.labels[0],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should fail if input not a LabelTensor
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
div(output_=output_, input_=input_.tensor)
|
||||||
|
|
||||||
|
# Should fail if output not a LabelTensor
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
div(output_=output_.tensor, input_=input_)
|
||||||
|
|
||||||
|
# Should fail for non-existent labels
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
div(output_=output_, input_=input_, d=["x", "y"])
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
div(output_=output_, input_=input_, components=["a", "b", "c"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"f",
|
||||||
|
Function(),
|
||||||
|
ids=["scalar_scalar", "scalar_vector", "vector_scalar", "vector_vector"],
|
||||||
|
)
|
||||||
|
def test_laplacian(f):
|
||||||
|
|
||||||
|
# Unpack the function
|
||||||
|
func_input, func, _, _, func_lap = f
|
||||||
|
|
||||||
|
# Define input and output
|
||||||
|
input_ = func_input()
|
||||||
|
output_ = func(input_)
|
||||||
|
labels = [f"u{i}" for i in range(output_.shape[-1])]
|
||||||
|
output_ = LabelTensor(output_, labels)
|
||||||
|
|
||||||
|
# Compute the true laplacian and the pina laplacian
|
||||||
|
pina_lap = laplacian(output_=output_, input_=input_)
|
||||||
|
true_lap = func_lap(input_)
|
||||||
|
|
||||||
|
# Check the shape and labels of the laplacian
|
||||||
|
assert pina_lap.shape == output_.shape
|
||||||
|
assert pina_lap.labels == [f"dd{l}" for l in output_.labels]
|
||||||
|
|
||||||
|
# Compare the values
|
||||||
|
assert torch.allclose(pina_lap, true_lap)
|
||||||
|
|
||||||
|
# Test if labels are handled correctly
|
||||||
|
laplacian(output_=output_, input_=input_, components=output_.labels[0])
|
||||||
|
laplacian(output_=output_, input_=input_, d=input_.labels[0])
|
||||||
|
|
||||||
|
# Should fail if input not a LabelTensor
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
laplacian(output_=output_, input_=input_.tensor)
|
||||||
|
|
||||||
|
# Should fail if output not a LabelTensor
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
laplacian(output_=output_.tensor, input_=input_)
|
||||||
|
|
||||||
|
# Should fail for non-existent input labels
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
laplacian(output_=output_, input_=input_, d=["x", "y"])
|
||||||
|
|
||||||
|
# Should fail for non-existent output labels
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
laplacian(output_=output_, input_=input_, components=["a", "b", "c"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_advection():
|
||||||
|
|
||||||
|
# Define input and output
|
||||||
|
input_ = torch.rand((20, 3), requires_grad=True)
|
||||||
|
input_ = LabelTensor(input_, ["x", "y", "z"])
|
||||||
|
output_ = LabelTensor(input_**2, ["u", "v", "c"])
|
||||||
|
|
||||||
|
# Define the velocity field
|
||||||
|
velocity = output_.extract(["c"])
|
||||||
|
|
||||||
|
# Compute the true advection and the pina advection
|
||||||
|
pina_advection = advection(
|
||||||
|
output_=output_, input_=input_, velocity_field="c"
|
||||||
)
|
)
|
||||||
assert grad_tensor_v.shape == (20, 9)
|
true_advection = velocity * 2 * input_.extract(["x", "y"])
|
||||||
assert grad_tensor_v.labels == [
|
|
||||||
f"d{j}d{i}" for j in tensor_v.labels for i in inp.labels
|
|
||||||
]
|
|
||||||
assert torch.allclose(grad_tensor_v, true_val)
|
|
||||||
|
|
||||||
grad_tensor_v = grad(tensor_v, inp, d=["x", "y"])
|
# Check the shape of the advection
|
||||||
true_val = torch.cat(
|
assert pina_advection.shape == (*output_.shape[:-1], output_.shape[-1] - 1)
|
||||||
(
|
assert torch.allclose(pina_advection, true_advection)
|
||||||
2 * inp.extract(["x"]),
|
|
||||||
torch.zeros_like(inp.extract(["y"])),
|
|
||||||
torch.zeros_like(inp.extract(["x"])),
|
|
||||||
2 * inp.extract(["y"]),
|
|
||||||
torch.zeros_like(inp.extract(["x"])),
|
|
||||||
torch.zeros_like(inp.extract(["y"])),
|
|
||||||
),
|
|
||||||
dim=1,
|
|
||||||
)
|
|
||||||
assert grad_tensor_v.shape == (inp.shape[0], 6)
|
|
||||||
assert grad_tensor_v.labels == [
|
|
||||||
f"d{j}d{i}" for j in tensor_v.labels for i in ["x", "y"]
|
|
||||||
]
|
|
||||||
assert torch.allclose(grad_tensor_v, true_val)
|
|
||||||
|
|
||||||
|
|
||||||
def test_div_vector_output():
|
|
||||||
div_tensor_v = div(tensor_v, inp)
|
|
||||||
true_val = 2 * torch.sum(inp, dim=1).reshape(-1, 1)
|
|
||||||
assert div_tensor_v.shape == (20, 1)
|
|
||||||
assert div_tensor_v.labels == [f"dadx+dbdy+dcdz"]
|
|
||||||
assert torch.allclose(div_tensor_v, true_val)
|
|
||||||
|
|
||||||
div_tensor_v = div(tensor_v, inp, components=["a", "b"], d=["x", "y"])
|
|
||||||
true_val = 2 * torch.sum(inp.extract(["x", "y"]), dim=1).reshape(-1, 1)
|
|
||||||
assert div_tensor_v.shape == (inp.shape[0], 1)
|
|
||||||
assert div_tensor_v.labels == [f"dadx+dbdy"]
|
|
||||||
assert torch.allclose(div_tensor_v, true_val)
|
|
||||||
|
|
||||||
|
|
||||||
def test_laplacian_scalar_output():
|
|
||||||
laplace_tensor_s = laplacian(tensor_s, inp)
|
|
||||||
true_val = 6 * torch.ones_like(laplace_tensor_s)
|
|
||||||
assert laplace_tensor_s.shape == tensor_s.shape
|
|
||||||
assert laplace_tensor_s.labels == [f"dd{tensor_s.labels[0]}"]
|
|
||||||
assert torch.allclose(laplace_tensor_s, true_val)
|
|
||||||
|
|
||||||
laplace_tensor_s = laplacian(tensor_s, inp, components=["a"], d=["x", "y"])
|
|
||||||
true_val = 4 * torch.ones_like(laplace_tensor_s)
|
|
||||||
assert laplace_tensor_s.shape == tensor_s.shape
|
|
||||||
assert laplace_tensor_s.labels == [f"dd{tensor_s.labels[0]}"]
|
|
||||||
assert torch.allclose(laplace_tensor_s, true_val)
|
|
||||||
|
|
||||||
|
|
||||||
def test_laplacian_vector_output():
|
|
||||||
laplace_tensor_v = laplacian(tensor_v, inp)
|
|
||||||
print(laplace_tensor_v.labels)
|
|
||||||
print(tensor_v.labels)
|
|
||||||
true_val = 2 * torch.ones_like(tensor_v)
|
|
||||||
assert laplace_tensor_v.shape == tensor_v.shape
|
|
||||||
assert laplace_tensor_v.labels == [f"dd{i}" for i in tensor_v.labels]
|
|
||||||
assert torch.allclose(laplace_tensor_v, true_val)
|
|
||||||
|
|
||||||
laplace_tensor_v = laplacian(
|
|
||||||
tensor_v, inp, components=["a", "b"], d=["x", "y"]
|
|
||||||
)
|
|
||||||
true_val = 2 * torch.ones_like(tensor_v.extract(["a", "b"]))
|
|
||||||
assert laplace_tensor_v.shape == tensor_v.extract(["a", "b"]).shape
|
|
||||||
assert laplace_tensor_v.labels == [f"dd{i}" for i in ["a", "b"]]
|
|
||||||
assert torch.allclose(laplace_tensor_v, true_val)
|
|
||||||
|
|
||||||
|
|
||||||
def test_laplacian_vector_output2():
|
|
||||||
x = LabelTensor(
|
|
||||||
torch.linspace(0, 1, 10, requires_grad=True).reshape(-1, 1),
|
|
||||||
labels=["x"],
|
|
||||||
)
|
|
||||||
y = LabelTensor(
|
|
||||||
torch.linspace(3, 4, 10, requires_grad=True).reshape(-1, 1),
|
|
||||||
labels=["y"],
|
|
||||||
)
|
|
||||||
input_ = LabelTensor(torch.cat((x, y), dim=1), labels=["x", "y"])
|
|
||||||
|
|
||||||
# Construct two scalar functions:
|
|
||||||
# u = x**2 + y**2
|
|
||||||
# v = x**2 - y**2
|
|
||||||
u = LabelTensor(
|
|
||||||
input_.extract("x") ** 2 + input_.extract("y") ** 2, labels="u"
|
|
||||||
)
|
|
||||||
v = LabelTensor(
|
|
||||||
input_.extract("x") ** 2 - input_.extract("y") ** 2, labels="v"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Define a vector-valued function, whose components are u and v.
|
|
||||||
f = LabelTensor(torch.cat((u, v), dim=1), labels=["u", "v"])
|
|
||||||
|
|
||||||
# Compute the scalar laplacian of both u and v:
|
|
||||||
# Lap(u) = [4, 4, 4, ..., 4]
|
|
||||||
# Lap(v) = [0, 0, 0, ..., 0]
|
|
||||||
lap_u = laplacian(u, input_, components=["u"])
|
|
||||||
lap_v = laplacian(v, input_, components=["v"])
|
|
||||||
|
|
||||||
# Compute the laplacian of f: the two columns should correspond
|
|
||||||
# to the laplacians of u and v, respectively...
|
|
||||||
lap_f = laplacian(f, input_, components=["u", "v"])
|
|
||||||
|
|
||||||
assert torch.allclose(lap_f.extract("ddu"), lap_u)
|
|
||||||
assert torch.allclose(lap_f.extract("ddv"), lap_v)
|
|
||||||
|
|
||||||
|
|
||||||
def test_label_format():
|
|
||||||
# Testing the format of `components` or `d` in case of single str of length
|
|
||||||
# greater than 1; e.g.: "aaa".
|
|
||||||
# This test is conducted only for gradient and laplacian, since div is not
|
|
||||||
# implemented for single components.
|
|
||||||
inp.labels = ["xx", "yy", "zz"]
|
|
||||||
tensor_v = LabelTensor(func_vector(inp), ["aa", "bbb", "c"])
|
|
||||||
comp = tensor_v.labels[0]
|
|
||||||
single_d = inp.labels[0]
|
|
||||||
|
|
||||||
# Single component as string + list of d
|
|
||||||
grad_tensor_v = grad(tensor_v, inp, components=comp, d=None)
|
|
||||||
assert grad_tensor_v.labels == [f"d{comp}d{i}" for i in inp.labels]
|
|
||||||
|
|
||||||
lap_tensor_v = laplacian(tensor_v, inp, components=comp, d=None)
|
|
||||||
assert lap_tensor_v.labels == [f"dd{comp}"]
|
|
||||||
|
|
||||||
# Single component as list + list of d
|
|
||||||
grad_tensor_v = grad(tensor_v, inp, components=[comp], d=None)
|
|
||||||
assert grad_tensor_v.labels == [f"d{comp}d{i}" for i in inp.labels]
|
|
||||||
|
|
||||||
lap_tensor_v = laplacian(tensor_v, inp, components=[comp], d=None)
|
|
||||||
assert lap_tensor_v.labels == [f"dd{comp}"]
|
|
||||||
|
|
||||||
# List of components + single d as string
|
|
||||||
grad_tensor_v = grad(tensor_v, inp, components=None, d=single_d)
|
|
||||||
assert grad_tensor_v.labels == [f"d{i}d{single_d}" for i in tensor_v.labels]
|
|
||||||
|
|
||||||
lap_tensor_v = laplacian(tensor_v, inp, components=None, d=single_d)
|
|
||||||
assert lap_tensor_v.labels == [f"dd{i}" for i in tensor_v.labels]
|
|
||||||
|
|
||||||
# List of components + single d as list
|
|
||||||
grad_tensor_v = grad(tensor_v, inp, components=None, d=[single_d])
|
|
||||||
assert grad_tensor_v.labels == [f"d{i}d{single_d}" for i in tensor_v.labels]
|
|
||||||
|
|
||||||
lap_tensor_v = laplacian(tensor_v, inp, components=None, d=[single_d])
|
|
||||||
assert lap_tensor_v.labels == [f"dd{i}" for i in tensor_v.labels]
|
|
||||||
|
|||||||
Reference in New Issue
Block a user