🎨 Format Python code with psf/black

This commit is contained in:
ndem0
2024-02-09 11:25:00 +00:00
committed by Nicola Demo
parent 591aeeb02b
commit cbb43a5392
64 changed files with 1323 additions and 955 deletions

View File

@@ -1,4 +1,5 @@
"""Module for DeepONet model"""
import torch
import torch.nn as nn
from ..utils import check_consistency, is_function
@@ -24,12 +25,14 @@ class MIONet(torch.nn.Module):
"""
def __init__(self,
networks,
aggregator="*",
reduction="+",
scale=True,
translation=True):
def __init__(
self,
networks,
aggregator="*",
reduction="+",
scale=True,
translation=True,
):
"""
:param dict networks: The neural networks to use as
models. The ``dict`` takes as key a neural network, and
@@ -121,8 +124,9 @@ class MIONet(torch.nn.Module):
shapes.append(key(input_).shape[-1])
if not all(map(lambda x: x == shapes[0], shapes)):
raise ValueError('The passed networks have not the same '
'output dimension.')
raise ValueError(
"The passed networks have not the same " "output dimension."
)
# assign trunk and branch net with their input indeces
self.models = torch.nn.ModuleList(networks.keys())
@@ -133,10 +137,16 @@ class MIONet(torch.nn.Module):
self._init_reduction(reduction=reduction)
# scale and translation
self._scale = torch.nn.Parameter(torch.tensor(
[1.0])) if scale else torch.tensor([1.0])
self._trasl = torch.nn.Parameter(torch.tensor(
[1.0])) if translation else torch.tensor([1.0])
self._scale = (
torch.nn.Parameter(torch.tensor([1.0]))
if scale
else torch.tensor([1.0])
)
self._trasl = (
torch.nn.Parameter(torch.tensor([1.0]))
if translation
else torch.tensor([1.0])
)
@staticmethod
def _symbol_functions(**kwargs):
@@ -180,16 +190,18 @@ class MIONet(torch.nn.Module):
return x.extract(indeces)
except AttributeError:
raise RuntimeError(
'Not possible to extract input variables from tensor.'
' Ensure that the passed tensor is a LabelTensor or'
' pass list of integers to extract variables. For'
' more information refer to warning in the documentation.')
"Not possible to extract input variables from tensor."
" Ensure that the passed tensor is a LabelTensor or"
" pass list of integers to extract variables. For"
" more information refer to warning in the documentation."
)
elif isinstance(indeces[0], int):
return x[..., indeces]
else:
raise RuntimeError(
'Not able to extract right indeces for tensor.'
' For more information refer to warning in the documentation.')
"Not able to extract right indeces for tensor."
" For more information refer to warning in the documentation."
)
def forward(self, x):
"""
@@ -197,7 +209,7 @@ class MIONet(torch.nn.Module):
:param LabelTensor or torch.Tensor x: The input tensor for the forward call.
:return: The output computed by the DeepONet model.
:rtype: LabelTensor or torch.Tensor
:rtype: LabelTensor or torch.Tensor
"""
# forward pass
@@ -267,7 +279,7 @@ class DeepONet(MIONet):
DeepONet is a general architecture for learning Operators. Unlike
traditional machine learning methods DeepONet is designed to map
entire functions to other functions. It can be trained both with
entire functions to other functions. It can be trained both with
Physics Informed or Supervised learning strategies.
.. seealso::
@@ -280,15 +292,17 @@ class DeepONet(MIONet):
"""
def __init__(self,
branch_net,
trunk_net,
input_indeces_branch_net,
input_indeces_trunk_net,
aggregator="*",
reduction="+",
scale=True,
translation=True):
def __init__(
self,
branch_net,
trunk_net,
input_indeces_branch_net,
input_indeces_trunk_net,
aggregator="*",
reduction="+",
scale=True,
translation=True,
):
"""
:param torch.nn.Module branch_net: The neural network to use as branch
model. It has to take as input a :py:obj:`pina.label_tensor.LabelTensor`
@@ -363,14 +377,15 @@ class DeepONet(MIONet):
"""
networks = {
branch_net: input_indeces_branch_net,
trunk_net: input_indeces_trunk_net
trunk_net: input_indeces_trunk_net,
}
super().__init__(networks=networks,
aggregator=aggregator,
reduction=reduction,
scale=scale,
translation=translation)
super().__init__(
networks=networks,
aggregator=aggregator,
reduction=reduction,
scale=scale,
translation=translation,
)
def forward(self, x):
"""
@@ -378,11 +393,10 @@ class DeepONet(MIONet):
:param LabelTensor or torch.Tensor x: The input tensor for the forward call.
:return: The output computed by the DeepONet model.
:rtype: LabelTensor or torch.Tensor
:rtype: LabelTensor or torch.Tensor
"""
return super().forward(x)
@property
def branch_net(self):
"""