From f812d87727d602f23ad1f57b234438efafe51816 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 27 Sep 2024 10:35:37 +0200 Subject: [PATCH] :art: Format Python code with psf/black (#348) --- pina/model/__init__.py | 2 +- pina/model/spline.py | 118 ++++++++++++++++++++--------------------- 2 files changed, 60 insertions(+), 60 deletions(-) diff --git a/pina/model/__init__.py b/pina/model/__init__.py index fdecfb9..3224d0a 100644 --- a/pina/model/__init__.py +++ b/pina/model/__init__.py @@ -19,4 +19,4 @@ from .fno import FNO, FourierIntegralKernel from .base_no import KernelNeuralOperator from .avno import AveragingNeuralOperator from .lno import LowRankNeuralOperator -from .spline import Spline \ No newline at end of file +from .spline import Spline diff --git a/pina/model/spline.py b/pina/model/spline.py index 0b57815..2c5aa6e 100644 --- a/pina/model/spline.py +++ b/pina/model/spline.py @@ -3,7 +3,8 @@ import torch import torch.nn as nn from ..utils import check_consistency - + + class Spline(torch.nn.Module): def __init__(self, order=4, knots=None, control_points=None) -> None: @@ -31,38 +32,37 @@ class Spline(torch.nn.Module): self.control_points = control_points elif knots is not None: - print('Warning: control points will be initialized automatically.') - print(' experimental feature') + print("Warning: control points will be initialized automatically.") + print(" experimental feature") self.knots = knots n = len(knots) - order self.control_points = torch.nn.Parameter( - torch.zeros(n), requires_grad=True) - - elif control_points is not None: - print('Warning: knots will be initialized automatically.') - print(' experimental feature') - - self.control_points = control_points - - n = len(self.control_points)-1 - self.knots = { - 'type': 'auto', - 'min': 0, - 'max': 1, - 'n': n+2+self.order} - - else: - raise ValueError( - "Knots and control points cannot be both None." + torch.zeros(n), requires_grad=True ) + elif control_points is not None: + print("Warning: knots will be initialized automatically.") + print(" experimental feature") + + self.control_points = control_points + + n = len(self.control_points) - 1 + self.knots = { + "type": "auto", + "min": 0, + "max": 1, + "n": n + 2 + self.order, + } + + else: + raise ValueError("Knots and control points cannot be both None.") if self.knots.ndim != 1: raise ValueError("Knot vector must be one-dimensional.") def basis(self, x, k, i, t): - ''' + """ Recursive function to compute the basis functions of the spline. :param torch.Tensor x: points to be evaluated. @@ -71,28 +71,32 @@ class Spline(torch.nn.Module): :param torch.Tensor t: vector of knots :return: the basis functions evaluated at x :rtype: torch.Tensor - ''' - + """ + if k == 0: - a = torch.where(torch.logical_and(t[i] <= x, x < t[i+1]), 1.0, 0.0) + a = torch.where( + torch.logical_and(t[i] <= x, x < t[i + 1]), 1.0, 0.0 + ) if i == len(t) - self.order - 1: - a = torch.where(x == t[-1], 1.0, a) + a = torch.where(x == t[-1], 1.0, a) a.requires_grad_(True) return a - - if t[i+k] == t[i]: - c1 = torch.tensor([0.0]*len(x), requires_grad=True) + if t[i + k] == t[i]: + c1 = torch.tensor([0.0] * len(x), requires_grad=True) else: - c1 = (x - t[i])/(t[i+k] - t[i]) * self.basis(x, k-1, i, t) + c1 = (x - t[i]) / (t[i + k] - t[i]) * self.basis(x, k - 1, i, t) - if t[i+k+1] == t[i+1]: - c2 = torch.tensor([0.0]*len(x), requires_grad=True) + if t[i + k + 1] == t[i + 1]: + c2 = torch.tensor([0.0] * len(x), requires_grad=True) else: - c2 = (t[i+k+1] - x)/(t[i+k+1] - t[i+1]) * self.basis(x, k-1, i+1, t) + c2 = ( + (t[i + k + 1] - x) + / (t[i + k + 1] - t[i + 1]) + * self.basis(x, k - 1, i + 1, t) + ) return c1 + c2 - @property def control_points(self): @@ -101,50 +105,46 @@ class Spline(torch.nn.Module): @control_points.setter def control_points(self, value): if isinstance(value, dict): - if 'n' not in value: - raise ValueError('Invalid value for control_points') - n = value['n'] - dim = value.get('dim', 1) + if "n" not in value: + raise ValueError("Invalid value for control_points") + n = value["n"] + dim = value.get("dim", 1) value = torch.zeros(n, dim) if not isinstance(value, torch.Tensor): - raise ValueError('Invalid value for control_points') + raise ValueError("Invalid value for control_points") self._control_points = torch.nn.Parameter(value, requires_grad=True) @property def knots(self): return self._knots - + @knots.setter def knots(self, value): if isinstance(value, dict): - type_ = value.get('type', 'auto') - min_ = value.get('min', 0) - max_ = value.get('max', 1) - n = value.get('n', 10) + type_ = value.get("type", "auto") + min_ = value.get("min", 0) + max_ = value.get("max", 1) + n = value.get("n", 10) - if type_ == 'uniform': + if type_ == "uniform": value = torch.linspace(min_, max_, n + self.k + 1) - elif type_ == 'auto': - initial_knots = torch.ones(self.order+1)*min_ - final_knots = torch.ones(self.order+1)*max_ + elif type_ == "auto": + initial_knots = torch.ones(self.order + 1) * min_ + final_knots = torch.ones(self.order + 1) * max_ if n < self.order + 1: value = torch.concatenate((initial_knots, final_knots)) - elif n - 2*self.order + 1 == 1: - value = torch.Tensor([(max_ + min_)/2]) + elif n - 2 * self.order + 1 == 1: + value = torch.Tensor([(max_ + min_) / 2]) else: - value = torch.linspace(min_, max_, n - 2*self.order - 1) + value = torch.linspace(min_, max_, n - 2 * self.order - 1) - value = torch.concatenate( - ( - initial_knots, value, final_knots - ) - ) + value = torch.concatenate((initial_knots, value, final_knots)) if not isinstance(value, torch.Tensor): - raise ValueError('Invalid value for knots') + raise ValueError("Invalid value for knots") self._knots = value @@ -154,7 +154,7 @@ class Spline(torch.nn.Module): :param torch.Tensor x_: points to be evaluated. :return: the spline evaluated at x_ - :rtype: torch.Tensor + :rtype: torch.Tensor """ t = self.knots k = self.k @@ -163,4 +163,4 @@ class Spline(torch.nn.Module): basis = map(lambda i: self.basis(x_, k, i, t)[:, None], range(len(c))) y = (torch.cat(list(basis), dim=1) * c).sum(axis=1) - return y \ No newline at end of file + return y