use LabelTensor, fix minor, docs
This commit is contained in:
91
pina/model/deeponet.py
Normal file
91
pina/model/deeponet.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""Module for DeepONet model"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from pina.label_tensor import LabelTensor
|
||||
|
||||
|
||||
class DeepONet(torch.nn.Module):
|
||||
"""
|
||||
The PINA implementation of DeepONet network.
|
||||
|
||||
.. seealso::
|
||||
|
||||
**Original reference**: Lu, L., Jin, P., Pang, G. et al. *Learning
|
||||
nonlinear operators via DeepONet based on the universal approximation
|
||||
theorem of operators*. Nat Mach Intell 3, 218–229 (2021).
|
||||
DOI: `10.1038/s42256-021-00302-5
|
||||
<https://doi.org/10.1038/s42256-021-00302-5>`_
|
||||
|
||||
"""
|
||||
def __init__(self, branch_net, trunk_net, output_variables):
|
||||
"""
|
||||
:param torch.nn.Module branch_net: the neural network to use as branch
|
||||
model. It has to take as input a :class:`LabelTensor`. The number
|
||||
of dimension of the output has to be the same of the `trunk_net`.
|
||||
:param torch.nn.Module trunk_net: the neural network to use as trunk
|
||||
model. It has to take as input a :class:`LabelTensor`. The number
|
||||
of dimension of the output has to be the same of the `branch_net`.
|
||||
:param list(str) output_variables: the list containing the labels
|
||||
corresponding to the components of the output computed by the
|
||||
model.
|
||||
|
||||
:Example:
|
||||
>>> branch = FFN(input_variables=['a', 'c'], output_variables=20)
|
||||
>>> trunk = FFN(input_variables=['b'], output_variables=20)
|
||||
>>> onet = DeepONet(trunk_net=trunk, branch_net=branch
|
||||
>>> output_variables=output_vars)
|
||||
DeepONet(
|
||||
(trunk_net): FeedForward(
|
||||
(extra_features): Sequential()
|
||||
(model): Sequential(
|
||||
(0): Linear(in_features=1, out_features=20, bias=True)
|
||||
(1): Tanh()
|
||||
(2): Linear(in_features=20, out_features=20, bias=True)
|
||||
(3): Tanh()
|
||||
(4): Linear(in_features=20, out_features=10, bias=True)
|
||||
)
|
||||
)
|
||||
(branch_net): FeedForward(
|
||||
(extra_features): Sequential()
|
||||
(model): Sequential(
|
||||
(0): Linear(in_features=2, out_features=20, bias=True)
|
||||
(1): Tanh()
|
||||
(2): Linear(in_features=20, out_features=20, bias=True)
|
||||
(3): Tanh()
|
||||
(4): Linear(in_features=20, out_features=10, bias=True)
|
||||
)
|
||||
)
|
||||
)
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.trunk_net = trunk_net
|
||||
self.branch_net = branch_net
|
||||
|
||||
self.output_variables = output_variables
|
||||
self.output_dimension = len(output_variables)
|
||||
if self.output_dimension > 1:
|
||||
raise NotImplementedError('Vectorial DeepONet to be implemented')
|
||||
|
||||
@property
|
||||
def input_variables(self):
|
||||
"""The input variables of the model"""
|
||||
return self.trunk_net.input_variables + self.branch_net.input_variables
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Defines the computation performed at every call.
|
||||
|
||||
:param LabelTensor x: the input tensor.
|
||||
:return: the output computed by the model.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
branch_output = self.branch_net(
|
||||
x.extract(self.branch_net.input_variables))
|
||||
trunk_output = self.trunk_net(
|
||||
x.extract(self.trunk_net.input_variables))
|
||||
|
||||
output_ = torch.sum(branch_output * trunk_output, dim=1).reshape(-1, 1)
|
||||
|
||||
return LabelTensor(output_, self.output_variables)
|
||||
Reference in New Issue
Block a user