Format Python code with psf/black push (#325)

* 🎨 Format Python code with psf/black

committed by GitHub
parent cce9876751
commit 5445559cb2
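A minimal sketch of reproducing this formatting locally through black's Python API; the repository's actual workflow and configuration are not shown in this commit, so the default black.Mode() used here is an assumption.

import black

# Format a small snippet the way this commit does: black normalizes quote
# style and parenthesizes the operand of a unary minus applied to **, as in
# the RBF-kernel hunk further down.
src = "def quintic(r):\n    return -r**5\n"
print(black.format_str(src, mode=black.Mode()))
# expected output with black's default style:
# def quintic(r):
#     return -(r**5)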
@@ -2,8 +2,8 @@ __all__ = [
     "SwitchOptimizer",
     "R3Refinement",
     "MetricTracker",
-    "PINAProgressBar"
+    "PINAProgressBar",
 ]

 from .optimizer_callbacks import SwitchOptimizer
 from .adaptive_refinment_callbacks import R3Refinement
@@ -6,9 +6,12 @@ import torch
 import copy

 from pytorch_lightning.callbacks import Callback, TQDMProgressBar
-from lightning.pytorch.callbacks.progress.progress_bar import get_standard_metrics
+from lightning.pytorch.callbacks.progress.progress_bar import (
+    get_standard_metrics,
+)
 from pina.utils import check_consistency

+
 class MetricTracker(Callback):

     def __init__(self):
@@ -37,7 +40,7 @@ class MetricTracker(Callback):

     def on_train_epoch_end(self, trainer, pl_module):
         """
         Collect and track metrics at the end of each training epoch.

         :param trainer: The trainer object managing the training process.
         :type trainer: pytorch_lightning.Trainer
@@ -68,7 +71,8 @@ class MetricTracker(Callback):
 class PINAProgressBar(TQDMProgressBar):

     BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]"
-    def __init__(self, metrics='mean', **kwargs):
+
+    def __init__(self, metrics="mean", **kwargs):
         """
         PINA Implementation of a Lightning Callback for enriching the progress
         bar.
@@ -123,7 +127,7 @@ class PINAProgressBar(TQDMProgressBar):
         if pbar_metrics:
             pbar_metrics = {
                 key: pbar_metrics[key] for key in self._sorted_metrics
             }
         duplicates = list(standard_metrics.keys() & pbar_metrics.keys())
         if duplicates:
             rank_zero_warn(
@@ -133,7 +137,7 @@ class PINAProgressBar(TQDMProgressBar):
             )

         return {**standard_metrics, **pbar_metrics}

     def on_fit_start(self, trainer, pl_module):
         """
         Check that the metrics defined in the initialization are available,
@@ -145,9 +149,13 @@ class PINAProgressBar(TQDMProgressBar):
         """
         # Check if all keys in sort_keys are present in the dictionary
         for key in self._sorted_metrics:
-            if key not in trainer.solver.problem.conditions.keys() and key != 'mean':
+            if (
+                key not in trainer.solver.problem.conditions.keys()
+                and key != "mean"
+            ):
                 raise KeyError(f"Key '{key}' is not present in the dictionary")
         # add the loss pedix
         self._sorted_metrics = [
-            metric + '_loss' for metric in self._sorted_metrics]
+            metric + "_loss" for metric in self._sorted_metrics
+        ]
         return super().on_fit_start(trainer, pl_module)
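A self-contained illustration of the key check and "_loss" suffix handled in on_fit_start above; the condition names ("gamma1", "D") are hypothetical stand-ins for illustration, not taken from this diff.

# Hypothetical condition names used only for illustration.
conditions = {"gamma1": object(), "D": object()}
sorted_metrics = ["mean", "gamma1"]
for key in sorted_metrics:
    if key not in conditions.keys() and key != "mean":
        raise KeyError(f"Key '{key}' is not present in the dictionary")
# append the loss suffix, exactly as the callback does before reading metrics
sorted_metrics = [metric + "_loss" for metric in sorted_metrics]
print(sorted_metrics)  # ['mean_loss', 'gamma1_loss']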
@@ -13,7 +13,7 @@ __all__ = [
     "FourierFeatureEmbedding",
     "AVNOBlock",
     "LowRankBlock",
-    "RBFBlock"
+    "RBFBlock",
 ]

 from .convolution_2d import ContinuousConvBlock
@@ -6,65 +6,74 @@ from itertools import combinations_with_replacement
 import torch
 from ...utils import check_consistency

+
 def linear(r):
-    '''
+    """
     Linear radial basis function.
-    '''
+    """
     return -r

+
 def thin_plate_spline(r, eps=1e-7):
-    '''
+    """
     Thin plate spline radial basis function.
-    '''
+    """
     r = torch.clamp(r, min=eps)
     return r**2 * torch.log(r)

+
 def cubic(r):
-    '''
+    """
     Cubic radial basis function.
-    '''
+    """
     return r**3

+
 def quintic(r):
-    '''
+    """
     Quintic radial basis function.
-    '''
-    return -r**5
+    """
+    return -(r**5)

+
 def multiquadric(r):
-    '''
+    """
     Multiquadric radial basis function.
-    '''
+    """
     return -torch.sqrt(r**2 + 1)

+
 def inverse_multiquadric(r):
-    '''
+    """
     Inverse multiquadric radial basis function.
-    '''
-    return 1/torch.sqrt(r**2 + 1)
+    """
+    return 1 / torch.sqrt(r**2 + 1)

+
 def inverse_quadratic(r):
-    '''
+    """
     Inverse quadratic radial basis function.
-    '''
-    return 1/(r**2 + 1)
+    """
+    return 1 / (r**2 + 1)

+
 def gaussian(r):
-    '''
+    """
     Gaussian radial basis function.
-    '''
-    return torch.exp(-r**2)
+    """
+    return torch.exp(-(r**2))

+
 radial_functions = {
     "linear": linear,
     "thin_plate_spline": thin_plate_spline,
     "cubic": cubic,
     "quintic": quintic,
     "multiquadric": multiquadric,
     "inverse_multiquadric": inverse_multiquadric,
     "inverse_quadratic": inverse_quadratic,
-    "gaussian": gaussian
+    "gaussian": gaussian,
 }

 scale_invariant = {"linear", "thin_plate_spline", "cubic", "quintic"}

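A quick check, not part of the commit, that the extra parentheses black adds around these power expressions are purely cosmetic: ** binds tighter than unary minus, so -r**5 and -(r**5) evaluate identically.

import torch

r = torch.linspace(0.5, 2.0, 4)
assert torch.equal(-r**5, -(r**5))
assert torch.equal(torch.exp(-r**2), torch.exp(-(r**2)))
print("parenthesization is cosmetic only")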
@@ -73,8 +82,8 @@ min_degree_funcs = {
     "linear": 0,
     "thin_plate_spline": 1,
     "cubic": 1,
-    "quintic": 2
+    "quintic": 2,
 }


 class RBFBlock(torch.nn.Module):
@@ -88,6 +97,7 @@ class RBFBlock(torch.nn.Module):
    it is inspired from the implementation in `torchrbf.
    <https://github.com/ArmanMaesumi/torchrbf>`_
    """
+
    def __init__(
        self,
        neighbors=None,
@@ -207,7 +217,8 @@ class RBFBlock(torch.nn.Module):
         if value < min_degree:
             warnings.warn(
                 "`degree` is too small for this kernel. Setting to "
-                f"{min_degree}.", UserWarning,
+                f"{min_degree}.",
+                UserWarning,
             )
         self._degree = value

@@ -222,8 +233,9 @@ class RBFBlock(torch.nn.Module):
            )

        if isinstance(self.smoothing, (int, float)):
-            self.smoothing = torch.full((y.shape[0],), self.smoothing
-            ).float().to(y.device)
+            self.smoothing = (
+                torch.full((y.shape[0],), self.smoothing).float().to(y.device)
+            )

    def fit(self, y, d):
        """
@@ -243,15 +255,22 @@ class RBFBlock(torch.nn.Module):
            raise NotImplementedError("neighbors currently not supported")

        powers = RBFBlock.monomial_powers(self.y.shape[1], self.degree).to(
-            y.device)
+            y.device
+        )
        if powers.shape[0] > nobs:
-            raise ValueError("The data is not compatible with the "
-                             "requested degree.")
+            raise ValueError(
+                "The data is not compatible with the requested degree."
+            )

        if self.neighbors is None:
-            self._shift, self._scale, self._coeffs = RBFBlock.solve(self.y,
-                self.d.reshape((self.y.shape[0], -1)),
-                self.smoothing, self.kernel, self.epsilon, powers)
+            self._shift, self._scale, self._coeffs = RBFBlock.solve(
+                self.y,
+                self.d.reshape((self.y.shape[0], -1)),
+                self.smoothing,
+                self.kernel,
+                self.epsilon,
+                powers,
+            )

        self.powers = powers

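A hedged usage sketch of the block whose fit method is reformatted above. The import path pina.model.layers and the constructor defaults are assumptions inferred from the layers __init__ hunk earlier and from the torchrbf interpolator this block is inspired by; they are not confirmed by this diff.

import torch
from pina.model.layers import RBFBlock  # assumed import path

y = torch.rand(100, 2)         # observation points
d = torch.sin(6.0 * y[:, :1])  # values to interpolate at those points
block = RBFBlock(kernel="thin_plate_spline")  # kernel name from radial_functions above
block.fit(y, d)                # builds and solves the linear system via RBFBlock.solve
x = torch.rand(10, 2)          # query points
print(block(x).shape)          # expected: torch.Size([10, 1])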
@@ -411,8 +430,9 @@ class RBFBlock(torch.nn.Module):
            of the interpolator
        """

-        lhs, rhs, shift, scale = RBFBlock.build(y, d, smoothing, kernel,
-                                                epsilon, powers)
+        lhs, rhs, shift, scale = RBFBlock.build(
+            y, d, smoothing, kernel, epsilon, powers
+        )
        try:
            coeffs = torch.linalg.solve(lhs, rhs)
        except RuntimeError as e:
@@ -196,9 +196,10 @@ class AbstractProblem(metaclass=ABCMeta):

        # check consistency location
        locations_to_sample = [
-            condition for condition in self.conditions
-            if hasattr(self.conditions[condition], 'location')
+            condition
+            for condition in self.conditions
+            if hasattr(self.conditions[condition], "location")
        ]
        if locations == "all":
            # only locations that can be sampled
            locations = locations_to_sample
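A self-contained illustration of the reformatted comprehension above: keep only the conditions that expose a location attribute. The condition classes below are hypothetical stand-ins rather than PINA objects.

class _ConditionWithLocation:
    location = "interior"

class _ConditionWithData:
    pass

conditions = {"gamma1": _ConditionWithLocation(), "data": _ConditionWithData()}
locations_to_sample = [
    condition
    for condition in conditions
    if hasattr(conditions[condition], "location")
]
print(locations_to_sample)  # ['gamma1']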