Implement Dataset, Dataloader and DataModule class and fix SupervisedSolver

This commit is contained in:
FilippoOlivo
2024-10-16 11:24:37 +02:00
committed by Nicola Demo
parent b9753c34b2
commit c9304fb9bb
30 changed files with 770 additions and 784 deletions

View File

@@ -3,6 +3,7 @@ from copy import deepcopy, copy
import torch
from torch import Tensor
def issubset(a, b):
"""
Check if a is a subset of b.
@@ -45,7 +46,7 @@ class LabelTensor(torch.Tensor):
:return: labels of self
:rtype: list
"""
return self._labels[self.tensor.ndim-1]['dof']
return self._labels[self.tensor.ndim - 1]['dof']
@property
def full_labels(self):
@@ -103,23 +104,23 @@ class LabelTensor(torch.Tensor):
raise ValueError('labels_to_extract must be str or list or dict')
def _extract_from_list(self, labels_to_extract):
#Store locally all necessary obj/variables
# Store locally all necessary obj/variables
ndim = self.tensor.ndim
labels = self.full_labels
tensor = self.tensor
last_dim_label = self.labels
#Verify if all the labels in labels_to_extract are in last dimension
# Verify if all the labels in labels_to_extract are in last dimension
if set(labels_to_extract).issubset(last_dim_label) is False:
raise ValueError('Cannot extract a dof which is not in the original LabelTensor')
#Extract index to extract
# Extract index to extract
idx_to_extract = [last_dim_label.index(i) for i in labels_to_extract]
#Perform extraction
# Perform extraction
new_tensor = tensor[..., idx_to_extract]
#Manage labels
# Manage labels
new_labels = copy(labels)
last_dim_new_label = {ndim - 1: {
@@ -186,7 +187,7 @@ class LabelTensor(torch.Tensor):
# Perform cat on tensors
new_tensor = torch.cat(tensors, dim=dim)
#Update labels
# Update labels
labels = tensors[0].full_labels
labels.pop(dim)
new_labels_cat_dim = new_labels_cat_dim if len(set(new_labels_cat_dim)) == len(new_labels_cat_dim) \
@@ -265,13 +266,13 @@ class LabelTensor(torch.Tensor):
:raises ValueError: dof list contain duplicates or number of dof does not match with tensor shape
"""
tensor_shape = self.tensor.shape
#Check dimensionality
# Check dimensionality
for k, v in labels.items():
if len(v['dof']) != len(set(v['dof'])):
raise ValueError("dof must be unique")
if len(v['dof']) != tensor_shape[k]:
raise ValueError('Number of dof does not match with tensor dimension')
#Perform update
# Perform update
self._labels.update(labels)
def update_labels_from_list(self, labels):
@@ -310,7 +311,7 @@ class LabelTensor(torch.Tensor):
if mode == 'std':
# Call cat on last dimension
new_label_tensor = LabelTensor.cat([self, tensor], dim=self.tensor.ndim - 1)
elif mode=='cross':
elif mode == 'cross':
# Crete tensor and call cat on last dimension
tensor1 = self
tensor2 = tensor
@@ -318,7 +319,7 @@ class LabelTensor(torch.Tensor):
n2 = tensor2.shape[0]
tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels)
tensor2 = LabelTensor(tensor2.repeat_interleave(n1, dim=0), labels=tensor2.labels)
new_label_tensor = LabelTensor.cat([tensor1, tensor2], dim=self.tensor.ndim-1)
new_label_tensor = LabelTensor.cat([tensor1, tensor2], dim=self.tensor.ndim - 1)
else:
raise ValueError('mode must be either "std" or "cross"')
return new_label_tensor
@@ -366,10 +367,10 @@ class LabelTensor(torch.Tensor):
if hasattr(self, "labels"):
if isinstance(index[j], list):
new_labels.update({j: {'dof': [new_labels[j]['dof'][i] for i in index[1]],
'name':new_labels[j]['name']}})
'name': new_labels[j]['name']}})
else:
new_labels.update({j: {'dof': new_labels[j]['dof'][index[j]],
'name':new_labels[j]['name']}})
'name': new_labels[j]['name']}})
selected_lt.labels = new_labels
else:
@@ -382,12 +383,13 @@ class LabelTensor(torch.Tensor):
def sort_labels(self, dim=None):
def argsort(lst):
return sorted(range(len(lst)), key=lambda x: lst[x])
if dim is None:
dim = self.tensor.ndim-1
dim = self.tensor.ndim - 1
labels = self.full_labels[dim]['dof']
sorted_index = argsort(labels)
indexer = [slice(None)] * self.tensor.ndim
indexer[dim] = sorted_index
new_labels = deepcopy(self.full_labels)
new_labels[dim] = {'dof': sorted(labels), 'name': new_labels[dim]['name']}
return LabelTensor(self.tensor[indexer], new_labels)
return LabelTensor(self.tensor[indexer], new_labels)