fix bug network

This commit is contained in:
Dario Coscia
2023-11-13 12:25:40 +01:00
committed by Nicola Demo
parent ee39b39805
commit a9f14ac323
6 changed files with 127 additions and 80 deletions

View File

@@ -2,6 +2,8 @@ import torch
import torch.nn as nn
from ..utils import check_consistency
from .layers.fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D
from pina import LabelTensor
import warnings
class FNO(torch.nn.Module):
@@ -69,7 +71,7 @@ class FNO(torch.nn.Module):
elif dimensions == 3:
fourier_layer = FourierBlock3D
else:
NotImplementedError('FNO implemented only for 1D/2D/3D data.')
raise NotImplementedError('FNO implemented only for 1D/2D/3D data.')
# Here we build the FNO by stacking Fourier Blocks
@@ -137,6 +139,9 @@ class FNO(torch.nn.Module):
:return: The output tensor obtained from the FNO.
:rtype: torch.Tensor
"""
if isinstance(x, LabelTensor): #TODO remove when Network is fixed
warnings.warn('LabelTensor passed as input is not allowed, casting LabelTensor to Torch.Tensor')
x = x.as_subclass(torch.Tensor)
# lifting the input in higher dimensional space
x = self._lifting_net(x)