Files
PINA/pina/model/deeponet.py
2022-03-29 18:05:26 +02:00

92 lines
3.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Module for DeepONet model"""
import torch
import torch.nn as nn
from pina.label_tensor import LabelTensor
class DeepONet(torch.nn.Module):
"""
The PINA implementation of DeepONet network.
.. seealso::
**Original reference**: Lu, L., Jin, P., Pang, G. et al. *Learning
nonlinear operators via DeepONet based on the universal approximation
theorem of operators*. Nat Mach Intell 3, 218229 (2021).
DOI: `10.1038/s42256-021-00302-5
<https://doi.org/10.1038/s42256-021-00302-5>`_
"""
def __init__(self, branch_net, trunk_net, output_variables):
"""
:param torch.nn.Module branch_net: the neural network to use as branch
model. It has to take as input a :class:`LabelTensor`. The number
of dimension of the output has to be the same of the `trunk_net`.
:param torch.nn.Module trunk_net: the neural network to use as trunk
model. It has to take as input a :class:`LabelTensor`. The number
of dimension of the output has to be the same of the `branch_net`.
:param list(str) output_variables: the list containing the labels
corresponding to the components of the output computed by the
model.
:Example:
>>> branch = FFN(input_variables=['a', 'c'], output_variables=20)
>>> trunk = FFN(input_variables=['b'], output_variables=20)
>>> onet = DeepONet(trunk_net=trunk, branch_net=branch
>>> output_variables=output_vars)
DeepONet(
(trunk_net): FeedForward(
(extra_features): Sequential()
(model): Sequential(
(0): Linear(in_features=1, out_features=20, bias=True)
(1): Tanh()
(2): Linear(in_features=20, out_features=20, bias=True)
(3): Tanh()
(4): Linear(in_features=20, out_features=10, bias=True)
)
)
(branch_net): FeedForward(
(extra_features): Sequential()
(model): Sequential(
(0): Linear(in_features=2, out_features=20, bias=True)
(1): Tanh()
(2): Linear(in_features=20, out_features=20, bias=True)
(3): Tanh()
(4): Linear(in_features=20, out_features=10, bias=True)
)
)
)
"""
super().__init__()
self.trunk_net = trunk_net
self.branch_net = branch_net
self.output_variables = output_variables
self.output_dimension = len(output_variables)
if self.output_dimension > 1:
raise NotImplementedError('Vectorial DeepONet to be implemented')
@property
def input_variables(self):
"""The input variables of the model"""
return self.trunk_net.input_variables + self.branch_net.input_variables
def forward(self, x):
"""
Defines the computation performed at every call.
:param LabelTensor x: the input tensor.
:return: the output computed by the model.
:rtype: LabelTensor
"""
branch_output = self.branch_net(
x.extract(self.branch_net.input_variables))
trunk_output = self.trunk_net(
x.extract(self.trunk_net.input_variables))
output_ = torch.sum(branch_output * trunk_output, dim=1).reshape(-1, 1)
return LabelTensor(output_, self.output_variables)