Generic DeepONet (#68)

* generic deeponet
* example for generic deeponet
* adapt tests to new interface
This commit is contained in:
Francesco Andreuzzi
2023-01-11 12:07:19 +01:00
committed by GitHub
parent e227700fbc
commit 7ce080fd85
5 changed files with 280 additions and 37 deletions

View File

@@ -26,9 +26,11 @@ class FeedForward(torch.nn.Module):
`inner_size` are not considered.
:param iterable(torch.nn.Module) extra_features: the additional input
features to use ad augmented input.
:param bool bias: If `True` the MLP will consider some bias.
"""
def __init__(self, input_variables, output_variables, inner_size=20,
n_layers=2, func=nn.Tanh, layers=None, extra_features=None):
n_layers=2, func=nn.Tanh, layers=None, extra_features=None,
bias=True):
"""
"""
super().__init__()
@@ -62,7 +64,9 @@ class FeedForward(torch.nn.Module):
self.layers = []
for i in range(len(tmp_layers)-1):
self.layers.append(nn.Linear(tmp_layers[i], tmp_layers[i+1]))
self.layers.append(
nn.Linear(tmp_layers[i], tmp_layers[i + 1], bias=bias)
)
if isinstance(func, list):
self.functions = func
@@ -94,7 +98,7 @@ class FeedForward(torch.nn.Module):
if self.input_variables:
x = x.extract(self.input_variables)
for i, feature in enumerate(self.extra_features):
for feature in self.extra_features:
x = x.append(feature(x))
output = self.model(x).as_subclass(LabelTensor)