Implementation of DataLoader and DataModule (#383)

Refactoring for 0.2
* Data module, data loader and dataset
* Refactor LabelTensor
* Refactor solvers

Co-authored-by: dario-coscia <dariocos99@gmail.com>
This commit is contained in:
Filippo Olivo
2024-11-27 16:01:39 +01:00
committed by Nicola Demo
parent dd43c8304c
commit a27bd35443
34 changed files with 827 additions and 1349 deletions

View File

@@ -119,6 +119,7 @@ class LowRankBlock(torch.nn.Module):
:rtype: torch.Tensor
"""
# extract basis
coords = coords.as_subclass(torch.Tensor)
basis = self._basis(coords)
# reshape [B, N, D, 2*rank]
shape = list(basis.shape[:-1]) + [-1, 2 * self.rank]