"""Train a ``GatingGNO`` graph solver on the SISSAmathLab thermal-conduction dataset."""

from lightning import Trainer

from ThermalSolver.module import GraphSolver
from ThermalSolver.data_module import GraphDataModule
from ThermalSolver.model.local_gno import GatingGNO
from ThermalSolver.model.basic_gno import GNO  # noqa: F401  # currently unused; alternative model


def main(
    max_epochs: int = 100,
    accelerator: str = "cuda",
    devices: int = 1,
    split_name: str = "easy",
    batch_size: int = 8,
) -> None:
    """Build the data module, model and trainer, then run training.

    All parameters default to the values the script originally hard-coded,
    so calling ``main()`` with no arguments reproduces the original run.

    Args:
        max_epochs: Number of training epochs passed to the Lightning ``Trainer``.
        accelerator: Lightning accelerator name (e.g. ``"cuda"``).
        devices: Number of devices passed to the ``Trainer``.
        split_name: Dataset split/config name forwarded to ``GraphDataModule``.
        batch_size: Batch size forwarded to ``GraphDataModule``.
    """
    trainer = Trainer(max_epochs=max_epochs, accelerator=accelerator, devices=devices)

    # Train/val/test fractions sum to 1.0.
    data_module = GraphDataModule(
        hf_repo="SISSAmathLab/thermal-conduction",
        split_name=split_name,
        train_size=0.8,
        val_size=0.1,
        test_size=0.1,
        batch_size=batch_size,
    )

    # NOTE(review): channel counts are presumably matched to the dataset's
    # node/edge feature sizes (1 node input, 1 forcing, 3 edge features,
    # 1 output) — confirm against GraphDataModule before changing.
    model = GatingGNO(
        x_ch_node=1,
        f_ch_node=1,
        hidden=16,
        layers=8,
        edge_ch=3,
        out_ch=1,
    )

    solver = GraphSolver(model)
    trainer.fit(solver, datamodule=data_module)
    print("Done!")


if __name__ == "__main__":
    main()