fix
This commit is contained in:
@@ -6,6 +6,12 @@ from torch_geometric.data import Data
|
||||
from ..graph import Graph, LabelBatch
|
||||
from ..label_tensor import LabelTensor
|
||||
|
||||
STACK_FN_MAP = {
|
||||
"label_tensor": LabelTensor.stack,
|
||||
"tensor": torch.stack,
|
||||
"data": LabelBatch.from_data_list,
|
||||
}
|
||||
|
||||
|
||||
class PinaDatasetFactory:
|
||||
"""
|
||||
@@ -65,18 +71,18 @@ class PinaDataset(Dataset):
|
||||
self.automatic_batching = (
|
||||
automatic_batching if automatic_batching is not None else True
|
||||
)
|
||||
self.stack_fn = {}
|
||||
self._stack_fn = {}
|
||||
self.is_graph_dataset = False
|
||||
# Determine stacking functions for each data type (used in collate_fn)
|
||||
for k, v in data_dict.items():
|
||||
if isinstance(v, LabelTensor):
|
||||
self.stack_fn[k] = LabelTensor.stack
|
||||
self._stack_fn[k] = "label_tensor"
|
||||
elif isinstance(v, torch.Tensor):
|
||||
self.stack_fn[k] = torch.stack
|
||||
self._stack_fn[k] = "tensor"
|
||||
elif isinstance(v, list) and all(
|
||||
isinstance(item, (Data, Graph)) for item in v
|
||||
):
|
||||
self.stack_fn[k] = LabelBatch.from_data_list
|
||||
self._stack_fn[k] = "data"
|
||||
self.is_graph_dataset = True
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -84,6 +90,11 @@ class PinaDataset(Dataset):
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Return the length of the dataset.
|
||||
:return: The length of the dataset.
|
||||
:rtype: int
|
||||
"""
|
||||
return len(next(iter(self.data.values())))
|
||||
|
||||
def __getitem__(self, idx):
|
||||
@@ -113,10 +124,9 @@ class PinaDataset(Dataset):
|
||||
|
||||
to_return = {}
|
||||
for field_name, data in self.data.items():
|
||||
if self.stack_fn[field_name] is LabelBatch.from_data_list:
|
||||
to_return[field_name] = self.stack_fn[field_name](
|
||||
[data[i] for i in idx_list]
|
||||
)
|
||||
if self._stack_fn[field_name] == "data":
|
||||
fn = STACK_FN_MAP[self._stack_fn[field_name]]
|
||||
to_return[field_name] = fn([data[i] for i in idx_list])
|
||||
else:
|
||||
to_return[field_name] = data[idx_list]
|
||||
return to_return
|
||||
@@ -148,3 +158,14 @@ class PinaDataset(Dataset):
|
||||
:rtype: torch.Tensor | LabelTensor | Data | Graph
|
||||
"""
|
||||
return self.data["input"]
|
||||
|
||||
@property
|
||||
def stack_fn(self):
|
||||
"""
|
||||
Get the mapping of stacking functions for each data type in the dataset.
|
||||
|
||||
:return: A dictionary mapping condition names to their respective
|
||||
stacking function identifiers.
|
||||
:rtype: dict
|
||||
"""
|
||||
return {k: STACK_FN_MAP[v] for k, v in self._stack_fn.items()}
|
||||
|
||||
Reference in New Issue
Block a user