Implement Dataset, Dataloader and DataModule classes and fix SupervisedSolver
committed by Nicola Demo
parent b9753c34b2
commit c9304fb9bb
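The hunks below only touch LabelTensor; the Dataset/Dataloader/DataModule and SupervisedSolver changes named in the title are not shown in this diff. Purely as an illustrative sketch of how such a trio is usually wired together with torch.utils.data and pytorch_lightning (the class names SupervisedDataset and SupervisedDataModule are hypothetical and not taken from this commit):

import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl


class SupervisedDataset(Dataset):
    """Hypothetical dataset pairing input and target tensors."""

    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]


class SupervisedDataModule(pl.LightningDataModule):
    """Hypothetical DataModule wrapping the dataset into a DataLoader."""

    def __init__(self, inputs, targets, batch_size=32):
        super().__init__()
        self.dataset = SupervisedDataset(inputs, targets)
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True)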
@@ -3,6 +3,7 @@ from copy import deepcopy, copy
 import torch
+from torch import Tensor


 def issubset(a, b):
     """
     Check if a is a subset of b.
@@ -45,7 +46,7 @@ class LabelTensor(torch.Tensor):
         :return: labels of self
         :rtype: list
         """
-        return self._labels[self.tensor.ndim-1]['dof']
+        return self._labels[self.tensor.ndim - 1]['dof']

     @property
     def full_labels(self):
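For orientation, a minimal sketch of the per-dimension label layout that self._labels[self.tensor.ndim - 1]['dof'] reads from (the concrete values here are illustrative, not taken from the commit):

import torch

# One entry per tensor dimension, keyed by the dimension index; each entry
# stores its degrees of freedom ('dof') and a human-readable dimension name.
tensor = torch.rand(20, 3)
full_labels = {
    0: {'dof': range(20), 'name': 'samples'},
    1: {'dof': ['x', 'y', 'z'], 'name': 'space'},
}

# The labels property returns only the dof of the last dimension:
print(full_labels[tensor.ndim - 1]['dof'])   # ['x', 'y', 'z']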
@@ -103,23 +104,23 @@ class LabelTensor(torch.Tensor):
            raise ValueError('labels_to_extract must be str or list or dict')

     def _extract_from_list(self, labels_to_extract):
-        #Store locally all necessary obj/variables
+        # Store locally all necessary obj/variables
         ndim = self.tensor.ndim
         labels = self.full_labels
         tensor = self.tensor
         last_dim_label = self.labels

-        #Verify if all the labels in labels_to_extract are in last dimension
+        # Verify if all the labels in labels_to_extract are in last dimension
         if set(labels_to_extract).issubset(last_dim_label) is False:
             raise ValueError('Cannot extract a dof which is not in the original LabelTensor')

-        #Extract index to extract
+        # Extract index to extract
         idx_to_extract = [last_dim_label.index(i) for i in labels_to_extract]

-        #Perform extraction
+        # Perform extraction
         new_tensor = tensor[..., idx_to_extract]

-        #Manage labels
+        # Manage labels
         new_labels = copy(labels)

         last_dim_new_label = {ndim - 1: {
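The extraction above is ordinary fancy indexing on the last dimension; a self-contained sketch of the same idiom on a bare torch tensor (labels and shapes invented for illustration):

import torch

tensor = torch.rand(10, 3)
last_dim_label = ['u', 'v', 'p']          # labels of the last dimension
labels_to_extract = ['u', 'p']

# Refuse labels that are not present, mirroring the check in the diff
if not set(labels_to_extract).issubset(last_dim_label):
    raise ValueError('Cannot extract a dof which is not in the original tensor')

# Map labels to column indices and slice the last dimension
idx_to_extract = [last_dim_label.index(i) for i in labels_to_extract]
new_tensor = tensor[..., idx_to_extract]
print(new_tensor.shape)                   # torch.Size([10, 2])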
@@ -186,7 +187,7 @@ class LabelTensor(torch.Tensor):
         # Perform cat on tensors
         new_tensor = torch.cat(tensors, dim=dim)

-        #Update labels
+        # Update labels
         labels = tensors[0].full_labels
         labels.pop(dim)
         new_labels_cat_dim = new_labels_cat_dim if len(set(new_labels_cat_dim)) == len(new_labels_cat_dim) \
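For the cat path, a small stand-alone example of concatenating along a dimension and applying the same uniqueness test to the merged labels of that dimension (the fallback branch is truncated in the hunk above, so only the check is shown here; all names are illustrative):

import torch

a = torch.rand(4, 2)
b = torch.rand(4, 3)
new_tensor = torch.cat([a, b], dim=1)          # shape (4, 5)

# Labels of the concatenated dimension, collected from both inputs
new_labels_cat_dim = ['x', 'y'] + ['u', 'v', 'p']

# Same uniqueness test as in the diff: the merged dof are kept only if there are no duplicates
assert len(set(new_labels_cat_dim)) == len(new_labels_cat_dim)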
@@ -265,13 +266,13 @@ class LabelTensor(torch.Tensor):
         :raises ValueError: dof list contain duplicates or number of dof does not match with tensor shape
         """
         tensor_shape = self.tensor.shape
-        #Check dimensionality
+        # Check dimensionality
         for k, v in labels.items():
             if len(v['dof']) != len(set(v['dof'])):
                 raise ValueError("dof must be unique")
             if len(v['dof']) != tensor_shape[k]:
                 raise ValueError('Number of dof does not match with tensor dimension')
-        #Perform update
+        # Perform update
         self._labels.update(labels)

     def update_labels_from_list(self, labels):
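The validation above can be exercised on its own; a short sketch under the same dict layout (values illustrative):

import torch

tensor = torch.rand(5, 2)
tensor_shape = tensor.shape

# Candidate labels for dimension 1: must be unique and match that dimension's size
labels = {1: {'dof': ['x', 'y'], 'name': 'space'}}

for k, v in labels.items():
    if len(v['dof']) != len(set(v['dof'])):
        raise ValueError("dof must be unique")
    if len(v['dof']) != tensor_shape[k]:
        raise ValueError('Number of dof does not match with tensor dimension')
# Both checks passed: the entries can safely be merged into the stored labels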
@@ -310,7 +311,7 @@ class LabelTensor(torch.Tensor):
         if mode == 'std':
             # Call cat on last dimension
             new_label_tensor = LabelTensor.cat([self, tensor], dim=self.tensor.ndim - 1)
-        elif mode=='cross':
+        elif mode == 'cross':
             # Crete tensor and call cat on last dimension
             tensor1 = self
             tensor2 = tensor
@@ -318,7 +319,7 @@ class LabelTensor(torch.Tensor):
             n2 = tensor2.shape[0]
             tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels)
             tensor2 = LabelTensor(tensor2.repeat_interleave(n1, dim=0), labels=tensor2.labels)
-            new_label_tensor = LabelTensor.cat([tensor1, tensor2], dim=self.tensor.ndim-1)
+            new_label_tensor = LabelTensor.cat([tensor1, tensor2], dim=self.tensor.ndim - 1)
         else:
             raise ValueError('mode must be either "std" or "cross"')
         return new_label_tensor
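The 'cross' branch builds the Cartesian pairing of rows before concatenating; a stand-alone sketch of the same repeat / repeat_interleave pattern on plain tensors:

import torch

t1 = torch.tensor([[1.], [2.], [3.]])       # shape (3, 1)
t2 = torch.tensor([[10.], [20.]])           # shape (2, 1)
n1, n2 = t1.shape[0], t2.shape[0]

left = t1.repeat(n2, 1)                     # rows 1, 2, 3, 1, 2, 3        -> (6, 1)
right = t2.repeat_interleave(n1, dim=0)     # rows 10, 10, 10, 20, 20, 20  -> (6, 1)

# Every row of t1 paired with every row of t2
pairs = torch.cat([left, right], dim=-1)    # shape (6, 2)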
@@ -366,10 +367,10 @@ class LabelTensor(torch.Tensor):
         if hasattr(self, "labels"):
             if isinstance(index[j], list):
                 new_labels.update({j: {'dof': [new_labels[j]['dof'][i] for i in index[1]],
-                                       'name':new_labels[j]['name']}})
+                                       'name': new_labels[j]['name']}})
             else:
                 new_labels.update({j: {'dof': new_labels[j]['dof'][index[j]],
-                                       'name':new_labels[j]['name']}})
+                                       'name': new_labels[j]['name']}})

             selected_lt.labels = new_labels
         else:
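When a dimension is indexed with a list, the dof of that dimension must be subset in the same way so the labels stay aligned with the data; a minimal stand-alone illustration:

import torch

tensor = torch.rand(6, 3)
dof = ['x', 'y', 'z']                       # labels of dimension 1

index = [0, 2]                              # keep the first and third column
selected = tensor[:, index]                 # shape (6, 2)
selected_dof = [dof[i] for i in index]      # ['x', 'z'], still aligned with the data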
@@ -382,12 +383,13 @@ class LabelTensor(torch.Tensor):
     def sort_labels(self, dim=None):
         def argsort(lst):
             return sorted(range(len(lst)), key=lambda x: lst[x])

         if dim is None:
-            dim = self.tensor.ndim-1
+            dim = self.tensor.ndim - 1
         labels = self.full_labels[dim]['dof']
         sorted_index = argsort(labels)
         indexer = [slice(None)] * self.tensor.ndim
         indexer[dim] = sorted_index
         new_labels = deepcopy(self.full_labels)
-        return LabelTensor(self.tensor[indexer], new_labels)
+        new_labels[dim] = {'dof': sorted(labels), 'name': new_labels[dim]['name']}
+        return LabelTensor(self.tensor[indexer], new_labels)
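sort_labels works by computing the permutation that orders the label list and applying it along one dimension; a stand-alone sketch of the same argsort-plus-indexer idiom (the indexer list is converted to a tuple before indexing):

import torch

tensor = torch.rand(4, 3)
dof = ['y', 'x', 'z']                       # unsorted labels of the last dimension


def argsort(lst):
    return sorted(range(len(lst)), key=lambda x: lst[x])


sorted_index = argsort(dof)                 # [1, 0, 2]
indexer = [slice(None)] * tensor.ndim       # take everything ...
indexer[tensor.ndim - 1] = sorted_index     # ... except the dimension being sorted
sorted_tensor = tensor[tuple(indexer)]      # columns reordered to x, y, z
sorted_dof = sorted(dof)                    # ['x', 'y', 'z']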