Implementation of DataLoader and DataModule (#383)

Refactoring for 0.2
* Data module, data loader and dataset
* Refactor LabelTensor
* Refactor solvers

Co-authored-by: dario-coscia <dariocos99@gmail.com>
This commit is contained in:
Filippo Olivo
2024-11-27 16:01:39 +01:00
committed by Nicola Demo
parent dd43c8304c
commit a27bd35443
34 changed files with 827 additions and 1349 deletions

View File

@@ -63,11 +63,9 @@ def grad(output_, input_, components=None, d=None):
retain_graph=True,
allow_unused=True,
)[0]
gradients.labels = input_.labels
gradients = gradients.extract(d)
gradients.labels = input_.stored_labels
gradients = gradients[..., [input_.labels.index(i) for i in d]]
gradients.labels = [f"d{output_fieldname}d{i}" for i in d]
return gradients
if not isinstance(input_, LabelTensor):
@@ -216,7 +214,9 @@ def laplacian(output_, input_, components=None, d=None, method="std"):
to_append_tensors = []
for i, label in enumerate(grad_output.labels):
gg = grad(grad_output, input_, d=d, components=[label])
to_append_tensors.append(gg.extract([gg.labels[i]]))
gg = gg.extract([gg.labels[i]])
to_append_tensors.append(gg)
labels = [f"dd{components[0]}"]
result = LabelTensor.summation(tensors=to_append_tensors)
result.labels = labels