🎨 Format Python code with psf/black
This commit is contained in:
@@ -62,19 +62,26 @@ class AveragingNeuralOperator(KernelNeuralOperator):
|
||||
# check hidden dimensions match
|
||||
input_lifting_net = next(lifting_net.parameters()).size()[-1]
|
||||
output_lifting_net = lifting_net(
|
||||
torch.rand(size=next(lifting_net.parameters()).size())
|
||||
).shape[-1]
|
||||
projecting_net_input=next(projecting_net.parameters()).size()[-1]
|
||||
torch.rand(size=next(lifting_net.parameters()).size())
|
||||
).shape[-1]
|
||||
projecting_net_input = next(projecting_net.parameters()).size()[-1]
|
||||
|
||||
if len(field_indices)+len(coordinates_indices) != input_lifting_net:
|
||||
raise ValueError('The lifting_net must take as input the '
|
||||
'coordinates vector and the field vector.')
|
||||
|
||||
if output_lifting_net+len(coordinates_indices) != projecting_net_input:
|
||||
raise ValueError('The projecting_net input must be equal to'
|
||||
'the embedding dimension (which is the output) '
|
||||
'of the lifting_net plus the dimension of the '
|
||||
'coordinates, i.e. len(coordinates_indices).')
|
||||
if len(field_indices) + len(coordinates_indices) != input_lifting_net:
|
||||
raise ValueError(
|
||||
"The lifting_net must take as input the "
|
||||
"coordinates vector and the field vector."
|
||||
)
|
||||
|
||||
if (
|
||||
output_lifting_net + len(coordinates_indices)
|
||||
!= projecting_net_input
|
||||
):
|
||||
raise ValueError(
|
||||
"The projecting_net input must be equal to"
|
||||
"the embedding dimension (which is the output) "
|
||||
"of the lifting_net plus the dimension of the "
|
||||
"coordinates, i.e. len(coordinates_indices)."
|
||||
)
|
||||
|
||||
# assign
|
||||
self.coordinates_indices = coordinates_indices
|
||||
@@ -108,4 +115,4 @@ class AveragingNeuralOperator(KernelNeuralOperator):
|
||||
new_batch = self._integral_kernels(new_batch)
|
||||
new_batch = concatenate((new_batch, points_tmp), dim=2)
|
||||
new_batch = self._projection_operator(new_batch)
|
||||
return new_batch
|
||||
return new_batch
|
||||
|
||||
Reference in New Issue
Block a user