improve unrolling

This commit is contained in:
Filippo Olivo
2025-10-02 10:17:01 +02:00
parent c6c416e682
commit b07e305cb5
5 changed files with 322 additions and 105 deletions

View File

@@ -0,0 +1,17 @@
"""
Custom Data/Batch per gestire bene le boundary conditions.
"""
from typing import List
import torch
from torch_geometric.data import Data, Batch
B_KEYS: List[str] = ["boundary_mask"]
class MeshData(Data):
def __inc__(self, key, value, *args, **kwargs):
# questi campi sono INDICI di nodi, quindi incrementali con num_nodes
if key in B_KEYS:
return self.num_nodes
return super().__inc__(key, value, *args, **kwargs)