fix bug network
This commit is contained in:
committed by
Nicola Demo
parent
ee39b39805
commit
a9f14ac323
@@ -2,6 +2,8 @@ import torch
|
||||
import torch.nn as nn
|
||||
from ..utils import check_consistency
|
||||
from .layers.fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D
|
||||
from pina import LabelTensor
|
||||
import warnings
|
||||
|
||||
|
||||
class FNO(torch.nn.Module):
|
||||
@@ -69,7 +71,7 @@ class FNO(torch.nn.Module):
|
||||
elif dimensions == 3:
|
||||
fourier_layer = FourierBlock3D
|
||||
else:
|
||||
NotImplementedError('FNO implemented only for 1D/2D/3D data.')
|
||||
raise NotImplementedError('FNO implemented only for 1D/2D/3D data.')
|
||||
|
||||
# Here we build the FNO by stacking Fourier Blocks
|
||||
|
||||
@@ -137,6 +139,9 @@ class FNO(torch.nn.Module):
|
||||
:return: The output tensor obtained from the FNO.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
if isinstance(x, LabelTensor): #TODO remove when Network is fixed
|
||||
warnings.warn('LabelTensor passed as input is not allowed, casting LabelTensor to Torch.Tensor')
|
||||
x = x.as_subclass(torch.Tensor)
|
||||
|
||||
# lifting the input in higher dimensional space
|
||||
x = self._lifting_net(x)
|
||||
|
||||
Reference in New Issue
Block a user