Neural Operator fix and addition

* Building FNO for 1D/2D/3D data
* Fixing bug in trunk/branch net in DeepONEt
* Fixing type check bug in spectral conv
* Adding tests for FNO
* Fixing bug in Fourier Layer (conv1d/2d/3d)
This commit is contained in:
Dario Coscia
2023-09-09 22:09:34 +02:00
committed by Nicola Demo
parent 83ecdb0eab
commit 603f56d264
6 changed files with 315 additions and 33 deletions

View File

@@ -115,15 +115,19 @@ class SpectralConvBlock2D(nn.Module):
# check type consistency
check_consistency(input_numb_fields, int)
check_consistency(output_numb_fields, int)
if not isinstance(n_modes, (tuple, list)):
raise ValueError('expected n_modes to be a list or tuple of len two, '
'with each entry corresponding to the number of modes '
'for each dimension ')
if len(n_modes) != 2:
raise ValueError('expected n_modes to be a list or tuple of len two, '
'with each entry corresponding to the number of modes '
'for each dimension ')
check_consistency(n_modes, int)
if isinstance(n_modes, (tuple, list)):
if len(n_modes) != 2:
raise ValueError('Expected n_modes to be a list or tuple of len two, '
'with each entry corresponding to the number of modes '
'for each dimension ')
elif isinstance(n_modes, int):
n_modes = [n_modes]*2
else:
raise ValueError('Expected n_modes to be a list or tuple of len two, '
'with each entry corresponding to the number of modes '
'for each dimension; or an int value representing the '
'number of modes for all dimensions')
# assign variables
@@ -222,15 +226,19 @@ class SpectralConvBlock3D(nn.Module):
# check type consistency
check_consistency(input_numb_fields, int)
check_consistency(output_numb_fields, int)
if not isinstance(n_modes, (tuple, list)):
raise ValueError('expected n_modes to be a list or tuple of len three, '
'with each entry corresponding to the number of modes '
'for each dimension ')
if len(n_modes) != 3:
raise ValueError('expected n_modes to be a list or tuple of len three, '
'with each entry corresponding to the number of modes '
'for each dimension ')
check_consistency(n_modes, int)
if isinstance(n_modes, (tuple, list)):
if len(n_modes) != 3:
raise ValueError('Expected n_modes to be a list or tuple of len three, '
'with each entry corresponding to the number of modes '
'for each dimension ')
elif isinstance(n_modes, int):
n_modes = [n_modes]*3
else:
raise ValueError('Expected n_modes to be a list or tuple of len three, '
'with each entry corresponding to the number of modes '
'for each dimension; or an int value representing the '
'number of modes for all dimensions')
# assign variables
self._modes = n_modes