small fix
This commit is contained in:
@@ -6,6 +6,7 @@ from torch_geometric.data import Data
|
||||
from torch_geometric.loader import DataLoader
|
||||
from torch_geometric.utils import to_undirected
|
||||
from .mesh_data import MeshData
|
||||
import os
|
||||
|
||||
|
||||
class GraphDataModule(LightningDataModule):
|
||||
@@ -115,8 +116,8 @@ class GraphDataModule(LightningDataModule):
|
||||
pos=pos,
|
||||
edge_attr=edge_attr,
|
||||
y=temperature.unsqueeze(-1),
|
||||
boundary_mask=boundary_mask,
|
||||
boundary_values=torch.tensor(0),
|
||||
boundary_mask=torch.tensor(0), # Fake value (to fix)
|
||||
boundary_values=torch.tensor(0), # Fake value (to fix)
|
||||
)
|
||||
|
||||
return MeshData(
|
||||
@@ -143,15 +144,27 @@ class GraphDataModule(LightningDataModule):
|
||||
|
||||
def train_dataloader(self):
|
||||
return DataLoader(
|
||||
self.train_data, batch_size=self.batch_size, shuffle=True
|
||||
self.train_data,
|
||||
batch_size=self.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=8,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
def val_dataloader(self):
|
||||
return DataLoader(
|
||||
self.val_data, batch_size=self.batch_size, shuffle=False
|
||||
self.val_data,
|
||||
batch_size=self.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=8,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
def test_dataloader(self):
|
||||
return DataLoader(
|
||||
self.test_data, batch_size=self.batch_size, shuffle=False
|
||||
self.test_data,
|
||||
batch_size=self.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=8,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user