small fix

This commit is contained in:
Filippo Olivo
2025-10-05 10:36:23 +02:00
parent 7a6fbdb89c
commit 469b1c6e13
3 changed files with 142 additions and 40 deletions

View File

@@ -6,6 +6,7 @@ from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_undirected
from .mesh_data import MeshData
import os
class GraphDataModule(LightningDataModule):
@@ -115,8 +116,8 @@ class GraphDataModule(LightningDataModule):
pos=pos,
edge_attr=edge_attr,
y=temperature.unsqueeze(-1),
boundary_mask=boundary_mask,
boundary_values=torch.tensor(0),
boundary_mask=torch.tensor(0), # Fake value (to fix)
boundary_values=torch.tensor(0), # Fake value (to fix)
)
return MeshData(
@@ -143,15 +144,27 @@ class GraphDataModule(LightningDataModule):
def train_dataloader(self):
return DataLoader(
self.train_data, batch_size=self.batch_size, shuffle=True
self.train_data,
batch_size=self.batch_size,
shuffle=True,
num_workers=8,
pin_memory=True,
)
def val_dataloader(self):
return DataLoader(
self.val_data, batch_size=self.batch_size, shuffle=False
self.val_data,
batch_size=self.batch_size,
shuffle=False,
num_workers=8,
pin_memory=True,
)
def test_dataloader(self):
return DataLoader(
self.test_data, batch_size=self.batch_size, shuffle=False
self.test_data,
batch_size=self.batch_size,
shuffle=False,
num_workers=8,
pin_memory=True,
)