add datamodule
This commit is contained in:
13
tests/test_datamodule.py
Normal file
13
tests/test_datamodule.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from ThermalSolver.data_module import GraphDataModule
|
||||
|
||||
def test_graph_data_module():
|
||||
data_module = GraphDataModule(
|
||||
hf_repo="SISSAmathLab/thermal-conduction",
|
||||
split_name="pytest",
|
||||
train_size=0.8,
|
||||
val_size=0.1,
|
||||
test_size=0.1,
|
||||
batch_size=32,
|
||||
)
|
||||
data_module.prepare_data()
|
||||
data_module.setup("fit")
|
||||
Reference in New Issue
Block a user