🎨 Format Python code with psf/black (#452)
This commit is contained in:
committed by
GitHub
parent
c28a4bd2b9
commit
f0bff2438b
@@ -6,6 +6,7 @@ from .stride import Stride
|
|||||||
from .utils_convolution import optimizing
|
from .utils_convolution import optimizing
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
|
||||||
class PODBlock(torch.nn.Module):
|
class PODBlock(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
POD layer: it projects the input field on the proper orthogonal
|
POD layer: it projects the input field on the proper orthogonal
|
||||||
@@ -121,12 +122,15 @@ class PODBlock(torch.nn.Module):
|
|||||||
if X.device.type == "mps": # svd_lowrank not arailable for mps
|
if X.device.type == "mps": # svd_lowrank not arailable for mps
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"svd_lowrank not available for mps, using svd instead."
|
"svd_lowrank not available for mps, using svd instead."
|
||||||
"This may slow down computations.", ResourceWarning
|
"This may slow down computations.",
|
||||||
|
ResourceWarning,
|
||||||
)
|
)
|
||||||
self._basis = torch.svd(X.T)[0].T
|
self._basis = torch.svd(X.T)[0].T
|
||||||
else:
|
else:
|
||||||
if randomized:
|
if randomized:
|
||||||
warnings.warn("Considering a randomized algorithm to compute the POD basis")
|
warnings.warn(
|
||||||
|
"Considering a randomized algorithm to compute the POD basis"
|
||||||
|
)
|
||||||
self._basis = torch.svd_lowrank(X.T, q=X.shape[0])[0].T
|
self._basis = torch.svd_lowrank(X.T, q=X.shape[0])[0].T
|
||||||
else:
|
else:
|
||||||
self._basis = torch.svd(X.T)[0].T
|
self._basis = torch.svd(X.T)[0].T
|
||||||
|
|||||||
Reference in New Issue
Block a user