fix bug network
This commit is contained in:
committed by
Nicola Demo
parent
ee39b39805
commit
a9f14ac323
@@ -2,6 +2,8 @@ import torch
|
||||
import torch.nn as nn
|
||||
from ..utils import check_consistency
|
||||
from .layers.fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D
|
||||
from pina import LabelTensor
|
||||
import warnings
|
||||
|
||||
|
||||
class FNO(torch.nn.Module):
|
||||
@@ -69,7 +71,7 @@ class FNO(torch.nn.Module):
|
||||
elif dimensions == 3:
|
||||
fourier_layer = FourierBlock3D
|
||||
else:
|
||||
NotImplementedError('FNO implemented only for 1D/2D/3D data.')
|
||||
raise NotImplementedError('FNO implemented only for 1D/2D/3D data.')
|
||||
|
||||
# Here we build the FNO by stacking Fourier Blocks
|
||||
|
||||
@@ -137,6 +139,9 @@ class FNO(torch.nn.Module):
|
||||
:return: The output tensor obtained from the FNO.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
if isinstance(x, LabelTensor): #TODO remove when Network is fixed
|
||||
warnings.warn('LabelTensor passed as input is not allowed, casting LabelTensor to Torch.Tensor')
|
||||
x = x.as_subclass(torch.Tensor)
|
||||
|
||||
# lifting the input in higher dimensional space
|
||||
x = self._lifting_net(x)
|
||||
|
||||
@@ -56,6 +56,9 @@ class Network(torch.nn.Module):
|
||||
:param torch.Tensor x: Input of the network.
|
||||
:return torch.Tensor: Output of the network.
|
||||
"""
|
||||
# only labeltensors as input
|
||||
assert isinstance(x, LabelTensor), "Expected LabelTensor as input to the model."
|
||||
|
||||
# extract torch.Tensor from corresponding label
|
||||
# in case `input_variables = []` all points are used
|
||||
if self._input_variables:
|
||||
@@ -65,22 +68,20 @@ class Network(torch.nn.Module):
|
||||
for feature in self._extra_features:
|
||||
x = x.append(feature(x))
|
||||
|
||||
# convert LabelTensor to torch.Tensor
|
||||
x = x.as_subclass(torch.Tensor)
|
||||
|
||||
# perform forward pass (using torch.Tensor) + converting to LabelTensor
|
||||
# perform forward pass + converting to LabelTensor
|
||||
output = self._model(x).as_subclass(LabelTensor)
|
||||
|
||||
# set the labels for LabelTensor
|
||||
output.labels = self._output_variables
|
||||
|
||||
return output
|
||||
|
||||
|
||||
# TODO to remove in next releases (only used in GAROM solver)
|
||||
def forward_map(self, x):
|
||||
"""
|
||||
Forward method for Network class when the input is
|
||||
a tuple. This class implements the standard forward method,
|
||||
and it adds the possibility to pass extra features.
|
||||
a tuple. This class is simply a forward with the input casted as a
|
||||
tuple or list :class`torch.Tensor`.
|
||||
All the PINA models ``forward`` s are overriden
|
||||
by this class, to enable :class:`pina.label_tensor.LabelTensor` labels
|
||||
extraction.
|
||||
|
||||
Reference in New Issue
Block a user