13 lines
333 B
Python
13 lines
333 B
Python
from ThermalSolver.data_module import GraphDataModule
|
|
|
|
def test_graph_data_module():
    """Smoke-test ``GraphDataModule`` construction and fit-stage setup.

    The test makes no assertions about the resulting data: it passes as long
    as building the module, ``prepare_data()`` and ``setup("fit")`` all
    complete without raising.
    """
    # Constructor arguments gathered in one place so the configuration under
    # test is easy to read at a glance.
    module_kwargs = {
        "hf_repo": "SISSAmathLab/thermal-conduction",
        "split_name": "pytest",
        "train_size": 0.8,
        "val_size": 0.1,
        "test_size": 0.1,
        "batch_size": 32,
    }
    data_module = GraphDataModule(**module_kwargs)

    # NOTE(review): presumably prepare_data() fetches the dataset named by
    # hf_repo, which would make this test network-dependent -- confirm.
    data_module.prepare_data()
    data_module.setup("fit")