🎨 Format Python code with psf/black (#452)

This commit is contained in:
github-actions[bot]
2025-02-15 15:45:42 +01:00
committed by GitHub
parent c28a4bd2b9
commit f0bff2438b

View File

@@ -6,6 +6,7 @@ from .stride import Stride
from .utils_convolution import optimizing from .utils_convolution import optimizing
import warnings import warnings
class PODBlock(torch.nn.Module): class PODBlock(torch.nn.Module):
""" """
POD layer: it projects the input field on the proper orthogonal POD layer: it projects the input field on the proper orthogonal
@@ -121,12 +122,15 @@ class PODBlock(torch.nn.Module):
if X.device.type == "mps": # svd_lowrank not arailable for mps if X.device.type == "mps": # svd_lowrank not arailable for mps
warnings.warn( warnings.warn(
"svd_lowrank not available for mps, using svd instead." "svd_lowrank not available for mps, using svd instead."
"This may slow down computations.", ResourceWarning "This may slow down computations.",
ResourceWarning,
) )
self._basis = torch.svd(X.T)[0].T self._basis = torch.svd(X.T)[0].T
else: else:
if randomized: if randomized:
warnings.warn("Considering a randomized algorithm to compute the POD basis") warnings.warn(
"Considering a randomized algorithm to compute the POD basis"
)
self._basis = torch.svd_lowrank(X.T, q=X.shape[0])[0].T self._basis = torch.svd_lowrank(X.T, q=X.shape[0])[0].T
else: else:
self._basis = torch.svd(X.T)[0].T self._basis = torch.svd(X.T)[0].T