🎨 Format Python code with psf/black
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
"""Module for DeepONet model"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ..utils import check_consistency, is_function
|
||||
@@ -24,12 +25,14 @@ class MIONet(torch.nn.Module):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
networks,
|
||||
aggregator="*",
|
||||
reduction="+",
|
||||
scale=True,
|
||||
translation=True):
|
||||
def __init__(
|
||||
self,
|
||||
networks,
|
||||
aggregator="*",
|
||||
reduction="+",
|
||||
scale=True,
|
||||
translation=True,
|
||||
):
|
||||
"""
|
||||
:param dict networks: The neural networks to use as
|
||||
models. The ``dict`` takes as key a neural network, and
|
||||
@@ -121,8 +124,9 @@ class MIONet(torch.nn.Module):
|
||||
shapes.append(key(input_).shape[-1])
|
||||
|
||||
if not all(map(lambda x: x == shapes[0], shapes)):
|
||||
raise ValueError('The passed networks have not the same '
|
||||
'output dimension.')
|
||||
raise ValueError(
|
||||
"The passed networks have not the same " "output dimension."
|
||||
)
|
||||
|
||||
# assign trunk and branch net with their input indeces
|
||||
self.models = torch.nn.ModuleList(networks.keys())
|
||||
@@ -133,10 +137,16 @@ class MIONet(torch.nn.Module):
|
||||
self._init_reduction(reduction=reduction)
|
||||
|
||||
# scale and translation
|
||||
self._scale = torch.nn.Parameter(torch.tensor(
|
||||
[1.0])) if scale else torch.tensor([1.0])
|
||||
self._trasl = torch.nn.Parameter(torch.tensor(
|
||||
[1.0])) if translation else torch.tensor([1.0])
|
||||
self._scale = (
|
||||
torch.nn.Parameter(torch.tensor([1.0]))
|
||||
if scale
|
||||
else torch.tensor([1.0])
|
||||
)
|
||||
self._trasl = (
|
||||
torch.nn.Parameter(torch.tensor([1.0]))
|
||||
if translation
|
||||
else torch.tensor([1.0])
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _symbol_functions(**kwargs):
|
||||
@@ -180,16 +190,18 @@ class MIONet(torch.nn.Module):
|
||||
return x.extract(indeces)
|
||||
except AttributeError:
|
||||
raise RuntimeError(
|
||||
'Not possible to extract input variables from tensor.'
|
||||
' Ensure that the passed tensor is a LabelTensor or'
|
||||
' pass list of integers to extract variables. For'
|
||||
' more information refer to warning in the documentation.')
|
||||
"Not possible to extract input variables from tensor."
|
||||
" Ensure that the passed tensor is a LabelTensor or"
|
||||
" pass list of integers to extract variables. For"
|
||||
" more information refer to warning in the documentation."
|
||||
)
|
||||
elif isinstance(indeces[0], int):
|
||||
return x[..., indeces]
|
||||
else:
|
||||
raise RuntimeError(
|
||||
'Not able to extract right indeces for tensor.'
|
||||
' For more information refer to warning in the documentation.')
|
||||
"Not able to extract right indeces for tensor."
|
||||
" For more information refer to warning in the documentation."
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
@@ -197,7 +209,7 @@ class MIONet(torch.nn.Module):
|
||||
|
||||
:param LabelTensor or torch.Tensor x: The input tensor for the forward call.
|
||||
:return: The output computed by the DeepONet model.
|
||||
:rtype: LabelTensor or torch.Tensor
|
||||
:rtype: LabelTensor or torch.Tensor
|
||||
"""
|
||||
|
||||
# forward pass
|
||||
@@ -267,7 +279,7 @@ class DeepONet(MIONet):
|
||||
|
||||
DeepONet is a general architecture for learning Operators. Unlike
|
||||
traditional machine learning methods DeepONet is designed to map
|
||||
entire functions to other functions. It can be trained both with
|
||||
entire functions to other functions. It can be trained both with
|
||||
Physics Informed or Supervised learning strategies.
|
||||
|
||||
.. seealso::
|
||||
@@ -280,15 +292,17 @@ class DeepONet(MIONet):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
branch_net,
|
||||
trunk_net,
|
||||
input_indeces_branch_net,
|
||||
input_indeces_trunk_net,
|
||||
aggregator="*",
|
||||
reduction="+",
|
||||
scale=True,
|
||||
translation=True):
|
||||
def __init__(
|
||||
self,
|
||||
branch_net,
|
||||
trunk_net,
|
||||
input_indeces_branch_net,
|
||||
input_indeces_trunk_net,
|
||||
aggregator="*",
|
||||
reduction="+",
|
||||
scale=True,
|
||||
translation=True,
|
||||
):
|
||||
"""
|
||||
:param torch.nn.Module branch_net: The neural network to use as branch
|
||||
model. It has to take as input a :py:obj:`pina.label_tensor.LabelTensor`
|
||||
@@ -363,14 +377,15 @@ class DeepONet(MIONet):
|
||||
"""
|
||||
networks = {
|
||||
branch_net: input_indeces_branch_net,
|
||||
trunk_net: input_indeces_trunk_net
|
||||
trunk_net: input_indeces_trunk_net,
|
||||
}
|
||||
super().__init__(networks=networks,
|
||||
aggregator=aggregator,
|
||||
reduction=reduction,
|
||||
scale=scale,
|
||||
translation=translation)
|
||||
|
||||
super().__init__(
|
||||
networks=networks,
|
||||
aggregator=aggregator,
|
||||
reduction=reduction,
|
||||
scale=scale,
|
||||
translation=translation,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
@@ -378,11 +393,10 @@ class DeepONet(MIONet):
|
||||
|
||||
:param LabelTensor or torch.Tensor x: The input tensor for the forward call.
|
||||
:return: The output computed by the DeepONet model.
|
||||
:rtype: LabelTensor or torch.Tensor
|
||||
:rtype: LabelTensor or torch.Tensor
|
||||
"""
|
||||
return super().forward(x)
|
||||
|
||||
|
||||
@property
|
||||
def branch_net(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user