Implement PinaGraphDataset
This commit is contained in:
committed by
Nicola Demo
parent
86fe41261b
commit
4c5e1569ff
@@ -93,8 +93,7 @@ class PinaTensorDataset(PinaDataset):
|
|||||||
|
|
||||||
|
|
||||||
class PinaGraphDataset(PinaDataset):
|
class PinaGraphDataset(PinaDataset):
|
||||||
pass
|
|
||||||
'''
|
|
||||||
def __init__(self, conditions_dict, max_conditions_lengths,
|
def __init__(self, conditions_dict, max_conditions_lengths,
|
||||||
automatic_batching):
|
automatic_batching):
|
||||||
super().__init__(conditions_dict, max_conditions_lengths)
|
super().__init__(conditions_dict, max_conditions_lengths)
|
||||||
@@ -113,7 +112,7 @@ class PinaGraphDataset(PinaDataset):
|
|||||||
to_return_dict[condition] = {k: Batch.from_data_list([v[i]
|
to_return_dict[condition] = {k: Batch.from_data_list([v[i]
|
||||||
for i in cond_idx])
|
for i in cond_idx])
|
||||||
if isinstance(v, list)
|
if isinstance(v, list)
|
||||||
else v[cond_idx]
|
else v[cond_idx].reshape(-1, *v[cond_idx].shape[2:])
|
||||||
for k, v in data.items()
|
for k, v in data.items()
|
||||||
}
|
}
|
||||||
return to_return_dict
|
return to_return_dict
|
||||||
@@ -132,5 +131,4 @@ class PinaGraphDataset(PinaDataset):
|
|||||||
return self.fetch_from_idx_list(index)
|
return self.fetch_from_idx_list(index)
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
return self._getitem_func(idx)
|
return self._getitem_func(idx)
|
||||||
'''
|
|
||||||
Reference in New Issue
Block a user