Codacy correction
This commit is contained in:
committed by
Nicola Demo
parent
ea3d1924e7
commit
dd43c8304c
@@ -59,9 +59,11 @@ class BaseDataset(Dataset):
|
||||
keys = list(data.keys())
|
||||
if set(self.__slots__) == set(keys):
|
||||
self._populate_init_list(data)
|
||||
idx = [key for key, val in
|
||||
self.problem.collector.conditions_name.items() if
|
||||
val == name]
|
||||
idx = [
|
||||
key for key, val in
|
||||
self.problem.collector.conditions_name.items()
|
||||
if val == name
|
||||
]
|
||||
self.conditions_idx.append(idx)
|
||||
self.initialize()
|
||||
|
||||
@@ -89,15 +91,16 @@ class BaseDataset(Dataset):
|
||||
if isinstance(slot_data, (LabelTensor, torch.Tensor)):
|
||||
dims = len(slot_data.size())
|
||||
slot_data = slot_data.permute(
|
||||
[batching_dim] + [dim for dim in range(dims) if
|
||||
dim != batching_dim])
|
||||
[batching_dim] +
|
||||
[dim for dim in range(dims) if dim != batching_dim])
|
||||
if current_cond_num_el is None:
|
||||
current_cond_num_el = len(slot_data)
|
||||
elif current_cond_num_el != len(slot_data):
|
||||
raise ValueError('Different dimension in same condition')
|
||||
current_list = getattr(self, slot)
|
||||
current_list += [slot_data] if not (
|
||||
isinstance(slot_data, list)) else slot_data
|
||||
current_list += [
|
||||
slot_data
|
||||
] if not (isinstance(slot_data, list)) else slot_data
|
||||
self.num_el_per_condition.append(current_cond_num_el)
|
||||
|
||||
def initialize(self):
|
||||
@@ -108,14 +111,12 @@ class BaseDataset(Dataset):
|
||||
logging.debug(f'Initialize dataset {self.__class__.__name__}')
|
||||
|
||||
if self.num_el_per_condition:
|
||||
self.condition_indices = torch.cat(
|
||||
[
|
||||
torch.tensor([i] * self.num_el_per_condition[i],
|
||||
dtype=torch.uint8)
|
||||
for i in range(len(self.num_el_per_condition))
|
||||
],
|
||||
dim=0
|
||||
)
|
||||
self.condition_indices = torch.cat([
|
||||
torch.tensor([i] * self.num_el_per_condition[i],
|
||||
dtype=torch.uint8)
|
||||
for i in range(len(self.num_el_per_condition))
|
||||
],
|
||||
dim=0)
|
||||
for slot in self.__slots__:
|
||||
current_attribute = getattr(self, slot)
|
||||
if all(isinstance(a, LabelTensor) for a in current_attribute):
|
||||
|
||||
Reference in New Issue
Block a user