diff --git a/experiments/config.yaml b/experiments/config.yaml new file mode 100644 index 0000000..c0b5602 --- /dev/null +++ b/experiments/config.yaml @@ -0,0 +1,48 @@ +# lightning.pytorch==2.5.5 +seed_everything: true +trainer: + accelerator: gpu + strategy: auto + devices: 1 + num_nodes: 1 + precision: null + logger: null + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/loss + mode: min + save_top_k: 1 + filename: best-checkpoint + max_epochs: 50 + min_epochs: null + max_steps: -1 + min_steps: null + overfit_batches: 0.0 + log_every_n_steps: null + inference_mode: true + default_root_dir: null +model: + class_path: ThermalSolver.module.GraphSolver + init_args: + model_class_path: ThermalSolver.model.local_gno.GatingGNO + model_init_args: + x_ch_node: 1 + f_ch_node: 1 + hidden: 16 + layers: 2 + edge_ch: 3 + out_ch: 1 + unrolling_steps: 10 +data: + class_path: ThermalSolver.data_module.GraphDataModule + init_args: + hf_repo: "SISSAmathLab/thermal-conduction" + split_name: "2000" + batch_size: 6 + train_size: 0.8 + test_size: 0.1 + test_size: 0.1 +optimizer: null +lr_scheduler: null +ckpt_path: null diff --git a/run.py b/run.py index 4d31328..368e10e 100644 --- a/run.py +++ b/run.py @@ -1,38 +1,12 @@ import torch -from lightning import Trainer -from ThermalSolver.module import GraphSolver -from ThermalSolver.data_module import GraphDataModule -from ThermalSolver.model.local_gno import GatingGNO +from lightning.pytorch.cli import LightningCLI + +torch.set_float32_matmul_precision("medium") def main(): - trainer = Trainer( - max_epochs=50, accelerator="cuda", devices=1, accumulate_grad_batches=3 - ) - data_module = GraphDataModule( - hf_repo="SISSAmathLab/thermal-conduction", - split_name="2000", - train_size=0.8, - val_size=0.1, - test_size=0.1, - batch_size=10, - ) - data_module.prepare_data() - data_module.setup("fit") - model = GatingGNO( - x_ch_node=1, f_ch_node=1, hidden=16, layers=2, edge_ch=3, out_ch=1 - ) - solver = GraphSolver(model, unrolling_steps=64) - - trainer.fit( - solver, - train_dataloaders=data_module.train_dataloader(), - val_dataloaders=data_module.val_dataloader(), - ) - data_module.setup("test") - trainer.test(solver, dataloaders=data_module.test_dataloader()) + LightningCLI(subclass_mode_data=True, subclass_mode_model=True) if __name__ == "__main__": - torch.set_float32_matmul_precision("medium") main()