""" Custom Data/Batch per gestire bene le boundary conditions. """ from typing import List import torch from torch_geometric.data import Data, Batch B_KEYS: List[str] = ["boundary_mask"] class MeshData(Data): def __inc__(self, key, value, *args, **kwargs): # questi campi sono INDICI di nodi, quindi incrementali con num_nodes if key in B_KEYS: return self.num_nodes return super().__inc__(key, value, *args, **kwargs)