Low Rank Neural Operator (#270)
* add the Low Rank Neural Operator as Model * add the Low Rank Layer as Layer * adding tests * adding doc --------- Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local> Co-authored-by: Nicola Demo <demo.nicola@gmail.com>
This commit is contained in:
@@ -57,6 +57,7 @@ Models
|
||||
FourierIntegralKernel <models/fourier_kernel.rst>
|
||||
FNO <models/fno.rst>
|
||||
AveragingNeuralOperator <models/avno.rst>
|
||||
LowRankNeuralOperator <models/lno.rst>
|
||||
|
||||
Layers
|
||||
-------------
|
||||
@@ -69,6 +70,7 @@ Layers
|
||||
Spectral convolution <layers/spectral.rst>
|
||||
Fourier layers <layers/fourier.rst>
|
||||
Averaging layer <layers/avno_layer.rst>
|
||||
Low Rank layer <layers/lowrank_layer.rst>
|
||||
Continuous convolution <layers/convolution.rst>
|
||||
Proper Orthogonal Decomposition <layers/pod.rst>
|
||||
Periodic Boundary Condition embeddings <layers/embedding.rst>
|
||||
|
||||
8
docs/source/_rst/layers/lowrank_layer.rst
Normal file
8
docs/source/_rst/layers/lowrank_layer.rst
Normal file
@@ -0,0 +1,8 @@
|
||||
Low Rank layer
|
||||
====================
|
||||
.. currentmodule:: pina.model.layers.lowrank_layer
|
||||
|
||||
.. autoclass:: LowRankBlock
|
||||
:members:
|
||||
:show-inheritance:
|
||||
:noindex:
|
||||
7
docs/source/_rst/models/lno.rst
Normal file
7
docs/source/_rst/models/lno.rst
Normal file
@@ -0,0 +1,7 @@
|
||||
Low Rank Neural Operator
|
||||
==============================
|
||||
.. currentmodule:: pina.model.lno
|
||||
|
||||
.. autoclass:: LowRankNeuralOperator
|
||||
:members:
|
||||
:show-inheritance:
|
||||
@@ -8,6 +8,7 @@ __all__ = [
|
||||
"FourierIntegralKernel",
|
||||
"KernelNeuralOperator",
|
||||
"AveragingNeuralOperator",
|
||||
"LowRankNeuralOperator"
|
||||
]
|
||||
|
||||
from .feed_forward import FeedForward, ResidualFeedForward
|
||||
@@ -16,3 +17,4 @@ from .deeponet import DeepONet, MIONet
|
||||
from .fno import FNO, FourierIntegralKernel
|
||||
from .base_no import KernelNeuralOperator
|
||||
from .avno import AveragingNeuralOperator
|
||||
from .lno import LowRankNeuralOperator
|
||||
|
||||
@@ -11,6 +11,7 @@ __all__ = [
|
||||
"PODBlock",
|
||||
"PeriodicBoundaryEmbedding",
|
||||
"AVNOBlock",
|
||||
"LowRankBlock",
|
||||
"AdaptiveActivationFunction",
|
||||
]
|
||||
|
||||
@@ -25,4 +26,5 @@ from .fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D
|
||||
from .pod import PODBlock
|
||||
from .embedding import PeriodicBoundaryEmbedding
|
||||
from .avno_layer import AVNOBlock
|
||||
from .lowrank_layer import LowRankBlock
|
||||
from .adaptive_func import AdaptiveActivationFunction
|
||||
135
pina/model/layers/lowrank_layer.py
Normal file
135
pina/model/layers/lowrank_layer.py
Normal file
@@ -0,0 +1,135 @@
|
||||
""" Module for Averaging Neural Operator Layer class. """
|
||||
|
||||
import torch
|
||||
|
||||
from pina.utils import check_consistency
|
||||
import pina.model as pm # avoid circular import
|
||||
|
||||
|
||||
class LowRankBlock(torch.nn.Module):
|
||||
r"""
|
||||
The PINA implementation of the inner layer of the Averaging Neural Operator.
|
||||
|
||||
The operator layer performs an affine transformation where the convolution
|
||||
is approximated with a local average. Given the input function
|
||||
:math:`v(x)\in\mathbb{R}^{\rm{emb}}` the layer computes
|
||||
the operator update :math:`K(v)` as:
|
||||
|
||||
.. math::
|
||||
K(v) = \sigma\left(Wv(x) + b + \sum_{i=1}^r \langle
|
||||
\psi^{(i)} , v(x) \rangle \phi^{(i)} \right)
|
||||
|
||||
where:
|
||||
|
||||
* :math:`\mathbb{R}^{\rm{emb}}` is the embedding (hidden) size
|
||||
corresponding to the ``hidden_size`` object
|
||||
* :math:`\sigma` is a non-linear activation, corresponding to the
|
||||
``func`` object
|
||||
* :math:`W\in\mathbb{R}^{\rm{emb}\times\rm{emb}}` is a tunable matrix.
|
||||
* :math:`b\in\mathbb{R}^{\rm{emb}}` is a tunable bias.
|
||||
* :math:`\psi^{(i)}\in\mathbb{R}^{\rm{emb}}` and
|
||||
:math:`\phi^{(i)}\in\mathbb{R}^{\rm{emb}}` are :math:`r` a low rank
|
||||
basis functions mapping.
|
||||
* :math:`b\in\mathbb{R}^{\rm{emb}}` is a tunable bias.
|
||||
|
||||
.. seealso::
|
||||
|
||||
**Original reference**: Kovachki, N., Li, Z., Liu, B.,
|
||||
Azizzadenesheli, K., Bhattacharya, K., Stuart, A., & Anandkumar, A.
|
||||
(2023). *Neural operator: Learning maps between function
|
||||
spaces with applications to PDEs*. Journal of Machine Learning
|
||||
Research, 24(89), 1-97.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
input_dimensions,
|
||||
embedding_dimenion,
|
||||
rank,
|
||||
inner_size=20,
|
||||
n_layers=2,
|
||||
func=torch.nn.Tanh,
|
||||
bias=True):
|
||||
"""
|
||||
:param int input_dimensions: The number of input components of the
|
||||
model.
|
||||
Expected tensor shape of the form :math:`(*, d)`, where *
|
||||
means any number of dimensions including none,
|
||||
and :math:`d` the ``input_dimensions``.
|
||||
:param int embedding_dimenion: Size of the embedding dimension of the
|
||||
field.
|
||||
:param int rank: The rank number of the basis approximation components
|
||||
of the model. Expected tensor shape of the form :math:`(*, 2d)`,
|
||||
where * means any number of dimensions including none,
|
||||
and :math:`2d` the ``rank`` for both basis functions.
|
||||
:param int inner_size: Number of neurons in the hidden layer(s) for the
|
||||
basis function network. Default is 20.
|
||||
:param int n_layers: Number of hidden layers. for the
|
||||
basis function network. Default is 2.
|
||||
:param func: The activation function to use for the
|
||||
basis function network. If a single
|
||||
:class:`torch.nn.Module` is passed, this is used as
|
||||
activation function after any layers, except the last one.
|
||||
If a list of Modules is passed,
|
||||
they are used as activation functions at any layers, in order.
|
||||
:param bool bias: If ``True`` the MLP will consider some bias for the
|
||||
basis function network.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# Assignment (check consistency inside FeedForward)
|
||||
self._basis = pm.FeedForward(input_dimensions=input_dimensions,
|
||||
output_dimensions=2*rank*embedding_dimenion,
|
||||
inner_size=inner_size, n_layers=n_layers,
|
||||
func=func, bias=bias)
|
||||
self._nn = torch.nn.Linear(embedding_dimenion, embedding_dimenion)
|
||||
|
||||
check_consistency(rank, int)
|
||||
self._rank = rank
|
||||
self._func = func()
|
||||
|
||||
def forward(self, x, coords):
|
||||
r"""
|
||||
Forward pass of the layer, it performs an affine transformation of
|
||||
the field, and a low rank approximation by
|
||||
doing a dot product of the basis
|
||||
:math:`\psi^{(i)}` with the filed vector :math:`v`, and use this
|
||||
coefficients to expand :math:`\phi^{(i)}` evaluated in the
|
||||
spatial input :math:`x`.
|
||||
|
||||
:param torch.Tensor x: The input tensor for performing the
|
||||
computation. It expects a tensor :math:`B \times N \times D`,
|
||||
where :math:`B` is the batch_size, :math:`N` the number of points
|
||||
in the mesh, :math:`D` the dimension of the problem. In particular
|
||||
:math:`D` is the codomain of the function :math:`v`. For example
|
||||
a scalar function has :math:`D=1`, a 4-dimensional vector function
|
||||
:math:`D=4`.
|
||||
:param torch.Tensor coords: The coordinates in which the field is
|
||||
evaluated for performing the computation. It expects a
|
||||
tensor :math:`B \times N \times d`,
|
||||
where :math:`B` is the batch_size, :math:`N` the number of points
|
||||
in the mesh, :math:`D` the dimension of the domain.
|
||||
:return: The output tensor obtained from Average Neural Operator Block.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
# extract basis
|
||||
basis = self._basis(coords)
|
||||
# reshape [B, N, D, 2*rank]
|
||||
shape = list(basis.shape[:-1]) + [-1, 2*self.rank]
|
||||
basis = basis.reshape(shape)
|
||||
# divide
|
||||
psi = basis[..., :self.rank]
|
||||
phi = basis[..., self.rank:]
|
||||
# compute dot product
|
||||
coeff = torch.einsum('...dr,...d->...r', psi,x)
|
||||
# expand the basis
|
||||
expansion = torch.einsum('...r,...dr->...d', coeff,phi)
|
||||
# apply linear layer and return
|
||||
return self._func(self._nn(x) + expansion)
|
||||
|
||||
@property
|
||||
def rank(self):
|
||||
"""
|
||||
The basis rank.
|
||||
"""
|
||||
return self._rank
|
||||
143
pina/model/lno.py
Normal file
143
pina/model/lno.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""Module LowRank Neural Operator."""
|
||||
|
||||
import torch
|
||||
from torch import nn, concatenate
|
||||
|
||||
from pina.utils import check_consistency
|
||||
|
||||
from .base_no import KernelNeuralOperator
|
||||
from .layers.lowrank_layer import LowRankBlock
|
||||
|
||||
|
||||
class LowRankNeuralOperator(KernelNeuralOperator):
|
||||
"""
|
||||
Implementation of LowRank Neural Operator.
|
||||
|
||||
LowRank Neural Operator is a general architecture for
|
||||
learning Operators. Unlike traditional machine learning methods
|
||||
LowRankNeuralOperator is designed to map entire functions
|
||||
to other functions. It can be trained with Supervised or PINN based
|
||||
learning strategies.
|
||||
LowRankNeuralOperator does convolution by performing a low rank
|
||||
approximation, see :class:`~pina.model.layers.lowrank_layer.LowRankBlock`.
|
||||
|
||||
.. seealso::
|
||||
|
||||
**Original reference**: Kovachki, N., Li, Z., Liu, B.,
|
||||
Azizzadenesheli, K., Bhattacharya, K., Stuart, A., & Anandkumar, A.
|
||||
(2023). *Neural operator: Learning maps between function
|
||||
spaces with applications to PDEs*. Journal of Machine Learning
|
||||
Research, 24(89), 1-97.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lifting_net,
|
||||
projecting_net,
|
||||
field_indices,
|
||||
coordinates_indices,
|
||||
n_kernel_layers,
|
||||
rank,
|
||||
inner_size=20,
|
||||
n_layers=2,
|
||||
func=torch.nn.Tanh,
|
||||
bias=True
|
||||
):
|
||||
"""
|
||||
:param torch.nn.Module lifting_net: The neural network for lifting
|
||||
the input. It must take as input the input field and the coordinates
|
||||
at which the input field is avaluated. The output of the lifting
|
||||
net is chosen as embedding dimension of the problem
|
||||
:param torch.nn.Module projecting_net: The neural network for
|
||||
projecting the output. It must take as input the embedding dimension
|
||||
(output of the ``lifting_net``) plus the dimension
|
||||
of the coordinates.
|
||||
:param list[str] field_indices: the label of the fields
|
||||
in the input tensor.
|
||||
:param list[str] coordinates_indices: the label of the
|
||||
coordinates in the input tensor.
|
||||
:param int n_kernel_layers: number of hidden kernel layers.
|
||||
Default is 4.
|
||||
:param int inner_size: Number of neurons in the hidden layer(s) for the
|
||||
basis function network. Default is 20.
|
||||
:param int n_layers: Number of hidden layers. for the
|
||||
basis function network. Default is 2.
|
||||
:param func: The activation function to use for the
|
||||
basis function network. If a single
|
||||
:class:`torch.nn.Module` is passed, this is used as
|
||||
activation function after any layers, except the last one.
|
||||
If a list of Modules is passed,
|
||||
they are used as activation functions at any layers, in order.
|
||||
:param bool bias: If ``True`` the MLP will consider some bias for the
|
||||
basis function network.
|
||||
"""
|
||||
|
||||
# check consistency
|
||||
check_consistency(field_indices, str)
|
||||
check_consistency(coordinates_indices, str)
|
||||
check_consistency(n_kernel_layers, int)
|
||||
|
||||
# check hidden dimensions match
|
||||
input_lifting_net = next(lifting_net.parameters()).size()[-1]
|
||||
output_lifting_net = lifting_net(
|
||||
torch.rand(size=next(lifting_net.parameters()).size())
|
||||
).shape[-1]
|
||||
projecting_net_input = next(projecting_net.parameters()).size()[-1]
|
||||
|
||||
if len(field_indices) + len(coordinates_indices) != input_lifting_net:
|
||||
raise ValueError(
|
||||
"The lifting_net must take as input the "
|
||||
"coordinates vector and the field vector."
|
||||
)
|
||||
|
||||
if (
|
||||
output_lifting_net + len(coordinates_indices)
|
||||
!= projecting_net_input
|
||||
):
|
||||
raise ValueError(
|
||||
"The projecting_net input must be equal to "
|
||||
"the embedding dimension (which is the output) "
|
||||
"of the lifting_net plus the dimension of the "
|
||||
"coordinates, i.e. len(coordinates_indices)."
|
||||
)
|
||||
|
||||
# assign
|
||||
self.coordinates_indices = coordinates_indices
|
||||
self.field_indices = field_indices
|
||||
integral_net = nn.Sequential(
|
||||
*[LowRankBlock(input_dimensions=len(coordinates_indices),
|
||||
embedding_dimenion=output_lifting_net,
|
||||
rank=rank,
|
||||
inner_size=inner_size,
|
||||
n_layers=n_layers,
|
||||
func=func,
|
||||
bias=bias) for _ in range(n_kernel_layers)]
|
||||
)
|
||||
super().__init__(lifting_net, integral_net, projecting_net)
|
||||
|
||||
def forward(self, x):
|
||||
r"""
|
||||
Forward computation for LowRank Neural Operator. It performs a
|
||||
lifting of the input by the ``lifting_net``. Then different layers
|
||||
of LowRank Neural Operator Blocks are applied.
|
||||
Finally the output is projected to the final dimensionality
|
||||
by the ``projecting_net``.
|
||||
|
||||
:param torch.Tensor x: The input tensor for fourier block,
|
||||
depending on ``dimension`` in the initialization. It expects
|
||||
a tensor :math:`B \times N \times D`,
|
||||
where :math:`B` is the batch_size, :math:`N` the number of points
|
||||
in the mesh, :math:`D` the dimension of the problem, i.e. the sum
|
||||
of ``len(coordinates_indices)+len(field_indices)``.
|
||||
:return: The output tensor obtained from Average Neural Operator.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
# extract points
|
||||
coords = x.extract(self.coordinates_indices)
|
||||
# lifting
|
||||
x = self._lifting_operator(x)
|
||||
# kernel
|
||||
for module in self._integral_kernels:
|
||||
x = module(x, coords)
|
||||
# projecting
|
||||
return self._projection_operator(concatenate((x, coords), dim=-1))
|
||||
58
tests/test_layers/test_lnolayer.py
Normal file
58
tests/test_layers/test_lnolayer.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import torch
|
||||
import pytest
|
||||
|
||||
from pina.model.layers import LowRankBlock
|
||||
from pina import LabelTensor
|
||||
|
||||
|
||||
input_dimensions=2
|
||||
embedding_dimenion=1
|
||||
rank=4
|
||||
inner_size=20
|
||||
n_layers=2
|
||||
func=torch.nn.Tanh
|
||||
bias=True
|
||||
|
||||
def test_constructor():
|
||||
LowRankBlock(input_dimensions=input_dimensions,
|
||||
embedding_dimenion=embedding_dimenion,
|
||||
rank=rank,
|
||||
inner_size=inner_size,
|
||||
n_layers=n_layers,
|
||||
func=func,
|
||||
bias=bias)
|
||||
|
||||
def test_constructor_wrong():
|
||||
with pytest.raises(ValueError):
|
||||
LowRankBlock(input_dimensions=input_dimensions,
|
||||
embedding_dimenion=embedding_dimenion,
|
||||
rank=0.5,
|
||||
inner_size=inner_size,
|
||||
n_layers=n_layers,
|
||||
func=func,
|
||||
bias=bias)
|
||||
|
||||
def test_forward():
|
||||
block = LowRankBlock(input_dimensions=input_dimensions,
|
||||
embedding_dimenion=embedding_dimenion,
|
||||
rank=rank,
|
||||
inner_size=inner_size,
|
||||
n_layers=n_layers,
|
||||
func=func,
|
||||
bias=bias)
|
||||
data = LabelTensor(torch.rand(10, 30, 3), labels=['x', 'y', 'u'])
|
||||
block(data.extract('u'), data.extract(['x', 'y']))
|
||||
|
||||
def test_backward():
|
||||
block = LowRankBlock(input_dimensions=input_dimensions,
|
||||
embedding_dimenion=embedding_dimenion,
|
||||
rank=rank,
|
||||
inner_size=inner_size,
|
||||
n_layers=n_layers,
|
||||
func=func,
|
||||
bias=bias)
|
||||
data = LabelTensor(torch.rand(10, 30, 3), labels=['x', 'y', 'u'])
|
||||
data.requires_grad_(True)
|
||||
out = block(data.extract('u'), data.extract(['x', 'y']))
|
||||
loss = out.mean()
|
||||
loss.backward()
|
||||
141
tests/test_model/test_lno.py
Normal file
141
tests/test_model/test_lno.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import torch
|
||||
from pina.model import LowRankNeuralOperator
|
||||
from pina import LabelTensor
|
||||
import pytest
|
||||
|
||||
|
||||
batch_size = 15
|
||||
n_layers = 4
|
||||
embedding_dim = 24
|
||||
func = torch.nn.Tanh
|
||||
rank = 4
|
||||
n_kernel_layers = 3
|
||||
field_indices = ['u']
|
||||
coordinates_indices = ['x', 'y']
|
||||
|
||||
def test_constructor():
|
||||
# working constructor
|
||||
lifting_net = torch.nn.Linear(len(coordinates_indices) + len(field_indices),
|
||||
embedding_dim)
|
||||
projecting_net = torch.nn.Linear(embedding_dim + len(coordinates_indices),
|
||||
len(field_indices))
|
||||
LowRankNeuralOperator(
|
||||
lifting_net=lifting_net,
|
||||
projecting_net=projecting_net,
|
||||
coordinates_indices=coordinates_indices,
|
||||
field_indices=field_indices,
|
||||
n_kernel_layers=n_kernel_layers,
|
||||
rank=rank)
|
||||
|
||||
# not working constructor
|
||||
with pytest.raises(ValueError):
|
||||
LowRankNeuralOperator(
|
||||
lifting_net=lifting_net,
|
||||
projecting_net=projecting_net,
|
||||
coordinates_indices=coordinates_indices,
|
||||
field_indices=field_indices,
|
||||
n_kernel_layers=3.2, # wrong
|
||||
rank=rank)
|
||||
|
||||
LowRankNeuralOperator(
|
||||
lifting_net=[0], # wrong
|
||||
projecting_net=projecting_net,
|
||||
coordinates_indices=coordinates_indices,
|
||||
field_indices=field_indices,
|
||||
n_kernel_layers=n_kernel_layers,
|
||||
rank=rank)
|
||||
|
||||
LowRankNeuralOperator(
|
||||
lifting_net=lifting_net,
|
||||
projecting_net=[0], # wront
|
||||
coordinates_indices=coordinates_indices,
|
||||
field_indices=field_indices,
|
||||
n_kernel_layers=n_kernel_layers,
|
||||
rank=rank)
|
||||
|
||||
LowRankNeuralOperator(
|
||||
lifting_net=lifting_net,
|
||||
projecting_net=projecting_net,
|
||||
coordinates_indices=[0], #wrong
|
||||
field_indices=field_indices,
|
||||
n_kernel_layers=n_kernel_layers,
|
||||
rank=rank)
|
||||
|
||||
LowRankNeuralOperator(
|
||||
lifting_net=lifting_net,
|
||||
projecting_net=projecting_net,
|
||||
coordinates_indices=coordinates_indices,
|
||||
field_indices=[0], #wrong
|
||||
n_kernel_layers=n_kernel_layers,
|
||||
rank=rank)
|
||||
|
||||
lifting_net = torch.nn.Linear(len(coordinates_indices),
|
||||
embedding_dim)
|
||||
LowRankNeuralOperator(
|
||||
lifting_net=lifting_net,
|
||||
projecting_net=projecting_net,
|
||||
coordinates_indices=coordinates_indices,
|
||||
field_indices=field_indices,
|
||||
n_kernel_layers=n_kernel_layers,
|
||||
rank=rank)
|
||||
|
||||
lifting_net = torch.nn.Linear(len(coordinates_indices) + len(field_indices),
|
||||
embedding_dim)
|
||||
projecting_net = torch.nn.Linear(embedding_dim,
|
||||
len(field_indices))
|
||||
LowRankNeuralOperator(
|
||||
lifting_net=lifting_net,
|
||||
projecting_net=projecting_net,
|
||||
coordinates_indices=coordinates_indices,
|
||||
field_indices=field_indices,
|
||||
n_kernel_layers=n_kernel_layers,
|
||||
rank=rank)
|
||||
|
||||
|
||||
def test_forward():
|
||||
lifting_net = torch.nn.Linear(len(coordinates_indices) + len(field_indices),
|
||||
embedding_dim)
|
||||
projecting_net = torch.nn.Linear(embedding_dim + len(coordinates_indices),
|
||||
len(field_indices))
|
||||
lno = LowRankNeuralOperator(
|
||||
lifting_net=lifting_net,
|
||||
projecting_net=projecting_net,
|
||||
coordinates_indices=coordinates_indices,
|
||||
field_indices=field_indices,
|
||||
n_kernel_layers=n_kernel_layers,
|
||||
rank=rank)
|
||||
|
||||
input_ = LabelTensor(
|
||||
torch.rand(batch_size, 100,
|
||||
len(coordinates_indices) + len(field_indices)),
|
||||
coordinates_indices + field_indices)
|
||||
|
||||
out = lno(input_)
|
||||
assert out.shape == torch.Size(
|
||||
[batch_size, input_.shape[1], len(field_indices)])
|
||||
|
||||
|
||||
def test_backward():
|
||||
lifting_net = torch.nn.Linear(len(coordinates_indices) + len(field_indices),
|
||||
embedding_dim)
|
||||
projecting_net = torch.nn.Linear(embedding_dim + len(coordinates_indices),
|
||||
len(field_indices))
|
||||
lno=LowRankNeuralOperator(
|
||||
lifting_net=lifting_net,
|
||||
projecting_net=projecting_net,
|
||||
coordinates_indices=coordinates_indices,
|
||||
field_indices=field_indices,
|
||||
n_kernel_layers=n_kernel_layers,
|
||||
rank=rank)
|
||||
input_ = LabelTensor(
|
||||
torch.rand(batch_size, 100,
|
||||
len(coordinates_indices) + len(field_indices)),
|
||||
coordinates_indices + field_indices)
|
||||
input_ = input_.requires_grad_()
|
||||
out = lno(input_)
|
||||
tmp = torch.linalg.norm(out)
|
||||
tmp.backward()
|
||||
grad = input_.grad
|
||||
assert grad.shape == torch.Size(
|
||||
[batch_size, input_.shape[1],
|
||||
len(coordinates_indices) + len(field_indices)])
|
||||
Reference in New Issue
Block a user