improve unrolling
This commit is contained in:
17
ThermalSolver/mesh_data.py
Normal file
17
ThermalSolver/mesh_data.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
Custom Data/Batch per gestire bene le boundary conditions.
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
import torch
|
||||
from torch_geometric.data import Data, Batch
|
||||
|
||||
B_KEYS: List[str] = ["boundary_mask"]
|
||||
|
||||
|
||||
class MeshData(Data):
|
||||
def __inc__(self, key, value, *args, **kwargs):
|
||||
# questi campi sono INDICI di nodi, quindi incrementali con num_nodes
|
||||
if key in B_KEYS:
|
||||
return self.num_nodes
|
||||
return super().__inc__(key, value, *args, **kwargs)
|
||||
Reference in New Issue
Block a user