diff --git a/pina/model/layers/orthogonal.py b/pina/model/layers/orthogonal.py index 1d7cddd..32a0607 100644 --- a/pina/model/layers/orthogonal.py +++ b/pina/model/layers/orthogonal.py @@ -49,22 +49,20 @@ class OrthogonalBlock(torch.nn.Module): result = torch.zeros_like(X, requires_grad=self._requires_grad) X_0 = torch.select(X, self.dim, 0).clone() - result_0 = X_0/torch.linalg.norm(X_0) + result_0 = X_0 / torch.linalg.norm(X_0) result = self._differentiable_copy(result, 0, result_0) # iterate over the rest of the basis with Gram-Schmidt for i in range(1, X.shape[self.dim]): v = torch.select(X, self.dim, i).clone() for j in range(i): - vj = torch.select(result,self.dim,j).clone() - v = v - torch.sum(v * vj, - dim=self.dim, keepdim=True) * vj - #result_i = torch.select(result, self.dim, i) - result_i = v/torch.linalg.norm(v) + vj = torch.select(result, self.dim, j).clone() + v = v - torch.sum(v * vj, dim=self.dim, keepdim=True) * vj + # result_i = torch.select(result, self.dim, i) + result_i = v / torch.linalg.norm(v) result = self._differentiable_copy(result, i, result_i) return result - def _differentiable_copy(self, result, idx, value): """ Perform a differentiable copy operation on a tensor. @@ -79,7 +77,7 @@ class OrthogonalBlock(torch.nn.Module): """ return result.index_copy( self.dim, torch.tensor([idx]), value.unsqueeze(self.dim) - ) + ) @property def dim(self): @@ -104,8 +102,10 @@ class OrthogonalBlock(torch.nn.Module): # check consistency check_consistency(value, int) if value not in [0, 1, -1]: - raise IndexError('Dimension out of range (expected to be in ' - f'range of [-1, 1], but got {value})') + raise IndexError( + "Dimension out of range (expected to be in " + f"range of [-1, 1], but got {value})" + ) # assign value self._dim = value