🎨 Format Python code with psf/black (#333)
Co-authored-by: ndem0 <ndem0@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
eea0cc0833
commit
1aca017e1d
@@ -57,14 +57,12 @@ class OrthogonalBlock(torch.nn.Module):
|
|||||||
v = torch.select(X, self.dim, i).clone()
|
v = torch.select(X, self.dim, i).clone()
|
||||||
for j in range(i):
|
for j in range(i):
|
||||||
vj = torch.select(result, self.dim, j).clone()
|
vj = torch.select(result, self.dim, j).clone()
|
||||||
v = v - torch.sum(v * vj,
|
v = v - torch.sum(v * vj, dim=self.dim, keepdim=True) * vj
|
||||||
dim=self.dim, keepdim=True) * vj
|
|
||||||
# result_i = torch.select(result, self.dim, i)
|
# result_i = torch.select(result, self.dim, i)
|
||||||
result_i = v / torch.linalg.norm(v)
|
result_i = v / torch.linalg.norm(v)
|
||||||
result = self._differentiable_copy(result, i, result_i)
|
result = self._differentiable_copy(result, i, result_i)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def _differentiable_copy(self, result, idx, value):
|
def _differentiable_copy(self, result, idx, value):
|
||||||
"""
|
"""
|
||||||
Perform a differentiable copy operation on a tensor.
|
Perform a differentiable copy operation on a tensor.
|
||||||
@@ -104,8 +102,10 @@ class OrthogonalBlock(torch.nn.Module):
|
|||||||
# check consistency
|
# check consistency
|
||||||
check_consistency(value, int)
|
check_consistency(value, int)
|
||||||
if value not in [0, 1, -1]:
|
if value not in [0, 1, -1]:
|
||||||
raise IndexError('Dimension out of range (expected to be in '
|
raise IndexError(
|
||||||
f'range of [-1, 1], but got {value})')
|
"Dimension out of range (expected to be in "
|
||||||
|
f"range of [-1, 1], but got {value})"
|
||||||
|
)
|
||||||
# assign value
|
# assign value
|
||||||
self._dim = value
|
self._dim = value
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user