Files
thermal-conduction-ml/ThermalSolver/mesh_data.py
2025-10-02 10:17:01 +02:00

18 lines
463 B
Python

"""
Custom PyG ``Data`` subclass so that boundary-condition fields are collated
correctly when graphs are combined into a ``Batch``.
"""
from typing import List
import torch
from torch_geometric.data import Data, Batch
# Attribute names that MeshData treats specially when graphs are collated into
# a Batch: node-index values stored under these keys are offset by each
# graph's num_nodes.
B_KEYS: List[str] = [
    "boundary_mask",
]
class MeshData(Data):
    """PyG ``Data`` subclass that batches boundary-condition node indices.

    During collation (e.g. ``Batch.from_data_list``), PyG shifts each graph's
    attribute by the value returned from ``__inc__``. For keys listed in
    ``B_KEYS`` this class returns ``num_nodes``, so node indices still point
    at the right nodes in the concatenated graph.
    """

    def __inc__(self, key, value, *args, **kwargs):
        """Return the collation offset for attribute ``key``.

        Args:
            key: Name of the attribute being collated.
            value: This graph's value for that attribute.
            *args, **kwargs: Forwarded unchanged to ``Data.__inc__``.

        Returns:
            For keys in ``B_KEYS``: ``0`` when ``value`` is a boolean or
            floating-point tensor (a per-node mask, which must not be
            shifted), otherwise ``self.num_nodes`` (node indices). For any
            other key, whatever ``Data.__inc__`` returns.
        """
        if key in B_KEYS:
            # Guard against the key holding a mask rather than indices: the
            # name says "mask" while the handling assumes indices. Adding an
            # int offset to a bool/float tensor would silently corrupt it.
            # NOTE(review): confirm which representation callers store here.
            if isinstance(value, torch.Tensor) and (
                value.dtype == torch.bool or value.is_floating_point()
            ):
                return 0
            # Node indices: shift by the node count of this graph so they
            # stay valid after concatenation with the preceding graphs.
            return self.num_nodes
        return super().__inc__(key, value, *args, **kwargs)