diff --git a/run.py b/run.py index e9bc24c..4d31328 100644 --- a/run.py +++ b/run.py @@ -1,24 +1,38 @@ +import torch from lightning import Trainer from ThermalSolver.module import GraphSolver from ThermalSolver.data_module import GraphDataModule from ThermalSolver.model.local_gno import GatingGNO -from ThermalSolver.model.basic_gno import GNO def main(): - trainer = Trainer(max_epochs=100, accelerator="cuda", devices=1) + trainer = Trainer( + max_epochs=50, accelerator="cuda", devices=1, accumulate_grad_batches=3 + ) data_module = GraphDataModule( hf_repo="SISSAmathLab/thermal-conduction", - split_name="easy", + split_name="2000", train_size=0.8, val_size=0.1, test_size=0.1, - batch_size=8, + batch_size=10, ) - model = GatingGNO(x_ch_node=1, f_ch_node=1, hidden=16, layers=8, edge_ch=3, out_ch=1) - solver = GraphSolver(model) - trainer.fit(solver, datamodule=data_module) - print("Done!") + data_module.prepare_data() + data_module.setup("fit") + model = GatingGNO( + x_ch_node=1, f_ch_node=1, hidden=16, layers=2, edge_ch=3, out_ch=1 + ) + solver = GraphSolver(model, unrolling_steps=64) + + trainer.fit( + solver, + train_dataloaders=data_module.train_dataloader(), + val_dataloaders=data_module.val_dataloader(), + ) + data_module.setup("test") + trainer.test(solver, dataloaders=data_module.test_dataloader()) + if __name__ == "__main__": - main() \ No newline at end of file + torch.set_float32_matmul_precision("medium") + main()