import torch
from lightning import Trainer

from ThermalSolver.module import GraphSolver
from ThermalSolver.data_module import GraphDataModule
from ThermalSolver.model.local_gno import GatingGNO


def main(
    *,
    accelerator: str = "cuda",
    devices: int = 1,
    max_epochs: int = 50,
) -> None:
    """Train a ``GatingGNO`` through ``GraphSolver`` and evaluate it on the test split.

    Builds a ``GraphDataModule`` over the ``SISSAmathLab/thermal-conduction``
    repository (split ``"2000"``, 80/10/10 train/val/test, batch size 10),
    fits the solver with a Lightning ``Trainer``, and then runs
    ``trainer.test`` on the test dataloader.

    All arguments are keyword-only and default to the values the script
    originally hard-coded, so ``main()`` behaves exactly as before.

    Args:
        accelerator: Passed to ``Trainer(accelerator=...)``. Use ``"cpu"`` or
            ``"auto"`` on machines without CUDA.
        devices: Passed to ``Trainer(devices=...)``.
        max_epochs: Passed to ``Trainer(max_epochs=...)``.
    """
    trainer = Trainer(
        max_epochs=max_epochs,
        accelerator=accelerator,
        devices=devices,
        # Step the optimizer once every 3 batches (gradients are accumulated
        # in between), i.e. an effective batch of 3 * batch_size.
        accumulate_grad_batches=3,
    )

    # NOTE(review): "hf_repo" suggests a Hugging Face Hub dataset repo and
    # prepare_data() presumably downloads it; confirm in GraphDataModule.
    data_module = GraphDataModule(
        hf_repo="SISSAmathLab/thermal-conduction",
        split_name="2000",
        train_size=0.8,
        val_size=0.1,
        test_size=0.1,
        batch_size=10,
    )
    # The dataloaders are requested explicitly below, so the data module is
    # prepared and set up by hand instead of being handed to the Trainer.
    data_module.prepare_data()
    data_module.setup("fit")

    model = GatingGNO(
        x_ch_node=1,
        f_ch_node=1,
        hidden=16,
        layers=2,
        edge_ch=3,
        out_ch=1,
    )
    solver = GraphSolver(model, unrolling_steps=64)

    trainer.fit(
        solver,
        train_dataloaders=data_module.train_dataloader(),
        val_dataloaders=data_module.val_dataloader(),
    )

    data_module.setup("test")
    # No ckpt_path is given, so this evaluates the solver's in-memory weights
    # after the final epoch. NOTE(review): pass ckpt_path="best" if the best
    # validation checkpoint is what should be reported.
    trainer.test(solver, dataloaders=data_module.test_dataloader())


if __name__ == "__main__":
    # Let float32 matmuls use faster, lower-precision internal computation
    # (trades some numerical accuracy for speed on supporting hardware).
    torch.set_float32_matmul_precision("medium")
    main()