from ThermalSolver.data_module import GraphDataModule


def test_graph_data_module():
    """Smoke-test GraphDataModule construction, data preparation and fit setup.

    Builds a ``GraphDataModule`` for the ``"pytest"`` split of the
    ``SISSAmathLab/thermal-conduction`` repository with an 80/10/10
    train/val/test split and a batch size of 32. It then calls
    ``prepare_data()`` followed by ``setup("fit")``. The test passes as long
    as none of these calls raises.
    """
    # NOTE(review): `hf_repo` presumably names a Hugging Face dataset, so this
    # test may need network access or a local cache. Confirm before running offline.
    module_kwargs = {
        "hf_repo": "SISSAmathLab/thermal-conduction",
        "split_name": "pytest",
        "train_size": 0.8,
        "val_size": 0.1,
        "test_size": 0.1,
        "batch_size": 32,
    }
    dm = GraphDataModule(**module_kwargs)

    dm.prepare_data()
    # TODO(review): no assertions yet. Consider checking the resulting splits
    # or dataloaders once the GraphDataModule API to inspect them is confirmed.
    dm.setup("fit")