🎨 Format Python code with psf/black (#333)

Co-authored-by: ndem0 <ndem0@users.noreply.github.com>
This commit is contained in:
github-actions[bot]
2024-09-03 16:53:39 +02:00
committed by GitHub
parent eea0cc0833
commit 1aca017e1d

View File

@@ -49,22 +49,20 @@ class OrthogonalBlock(torch.nn.Module):
result = torch.zeros_like(X, requires_grad=self._requires_grad) result = torch.zeros_like(X, requires_grad=self._requires_grad)
X_0 = torch.select(X, self.dim, 0).clone() X_0 = torch.select(X, self.dim, 0).clone()
result_0 = X_0/torch.linalg.norm(X_0) result_0 = X_0 / torch.linalg.norm(X_0)
result = self._differentiable_copy(result, 0, result_0) result = self._differentiable_copy(result, 0, result_0)
# iterate over the rest of the basis with Gram-Schmidt # iterate over the rest of the basis with Gram-Schmidt
for i in range(1, X.shape[self.dim]): for i in range(1, X.shape[self.dim]):
v = torch.select(X, self.dim, i).clone() v = torch.select(X, self.dim, i).clone()
for j in range(i): for j in range(i):
vj = torch.select(result,self.dim,j).clone() vj = torch.select(result, self.dim, j).clone()
v = v - torch.sum(v * vj, v = v - torch.sum(v * vj, dim=self.dim, keepdim=True) * vj
dim=self.dim, keepdim=True) * vj # result_i = torch.select(result, self.dim, i)
#result_i = torch.select(result, self.dim, i) result_i = v / torch.linalg.norm(v)
result_i = v/torch.linalg.norm(v)
result = self._differentiable_copy(result, i, result_i) result = self._differentiable_copy(result, i, result_i)
return result return result
def _differentiable_copy(self, result, idx, value): def _differentiable_copy(self, result, idx, value):
""" """
Perform a differentiable copy operation on a tensor. Perform a differentiable copy operation on a tensor.
@@ -79,7 +77,7 @@ class OrthogonalBlock(torch.nn.Module):
""" """
return result.index_copy( return result.index_copy(
self.dim, torch.tensor([idx]), value.unsqueeze(self.dim) self.dim, torch.tensor([idx]), value.unsqueeze(self.dim)
) )
@property @property
def dim(self): def dim(self):
@@ -104,8 +102,10 @@ class OrthogonalBlock(torch.nn.Module):
# check consistency # check consistency
check_consistency(value, int) check_consistency(value, int)
if value not in [0, 1, -1]: if value not in [0, 1, -1]:
raise IndexError('Dimension out of range (expected to be in ' raise IndexError(
f'range of [-1, 1], but got {value})') "Dimension out of range (expected to be in "
f"range of [-1, 1], but got {value})"
)
# assign value # assign value
self._dim = value self._dim = value