Generic DeepONet (#68)
* generic deeponet * example for generic deeponet * adapt tests to new interface
This commit is contained in:
committed by
GitHub
parent
e227700fbc
commit
7ce080fd85
@@ -26,9 +26,11 @@ class FeedForward(torch.nn.Module):
|
||||
`inner_size` are not considered.
|
||||
:param iterable(torch.nn.Module) extra_features: the additional input
|
||||
features to use ad augmented input.
|
||||
:param bool bias: If `True` the MLP will consider some bias.
|
||||
"""
|
||||
def __init__(self, input_variables, output_variables, inner_size=20,
|
||||
n_layers=2, func=nn.Tanh, layers=None, extra_features=None):
|
||||
n_layers=2, func=nn.Tanh, layers=None, extra_features=None,
|
||||
bias=True):
|
||||
"""
|
||||
"""
|
||||
super().__init__()
|
||||
@@ -62,7 +64,9 @@ class FeedForward(torch.nn.Module):
|
||||
|
||||
self.layers = []
|
||||
for i in range(len(tmp_layers)-1):
|
||||
self.layers.append(nn.Linear(tmp_layers[i], tmp_layers[i+1]))
|
||||
self.layers.append(
|
||||
nn.Linear(tmp_layers[i], tmp_layers[i + 1], bias=bias)
|
||||
)
|
||||
|
||||
if isinstance(func, list):
|
||||
self.functions = func
|
||||
@@ -94,7 +98,7 @@ class FeedForward(torch.nn.Module):
|
||||
if self.input_variables:
|
||||
x = x.extract(self.input_variables)
|
||||
|
||||
for i, feature in enumerate(self.extra_features):
|
||||
for feature in self.extra_features:
|
||||
x = x.append(feature(x))
|
||||
|
||||
output = self.model(x).as_subclass(LabelTensor)
|
||||
|
||||
Reference in New Issue
Block a user