Files
thermal-conduction-ml/run.py
2025-09-25 14:44:39 +02:00

24 lines
747 B
Python

from lightning import Trainer
from ThermalSolver.module import GraphSolver
from ThermalSolver.data_module import GraphDataModule
from ThermalSolver.model.local_gno import GatingGNO
from ThermalSolver.model.basic_gno import GNO
def main(
    *,
    hf_repo="SISSAmathLab/thermal-conduction",
    split_name="easy",
    batch_size=8,
    max_epochs=100,
    accelerator="cuda",
    devices=1,
    hidden=16,
    layers=8,
):
    """Train a GatingGNO-based GraphSolver on the thermal-conduction dataset.

    Every keyword argument defaults to the value this script originally
    hard-coded, so calling ``main()`` with no arguments reproduces the
    original run exactly.

    Args:
        hf_repo: Dataset repository id passed to ``GraphDataModule``.
        split_name: Dataset split name passed to ``GraphDataModule``.
        batch_size: Batch size passed to ``GraphDataModule``.
        max_epochs: ``max_epochs`` passed to the Lightning ``Trainer``.
        accelerator: Lightning accelerator name (the default ``"cuda"``
            needs a CUDA device; pass ``"cpu"`` or ``"auto"`` otherwise).
        devices: ``devices`` passed to the Lightning ``Trainer``.
        hidden: ``hidden`` argument passed to ``GatingGNO``.
        layers: ``layers`` argument passed to ``GatingGNO``.
    """
    trainer = Trainer(max_epochs=max_epochs, accelerator=accelerator, devices=devices)
    data_module = GraphDataModule(
        hf_repo=hf_repo,
        split_name=split_name,
        train_size=0.8,
        val_size=0.1,
        test_size=0.1,
        batch_size=batch_size,
    )
    # Channel counts stay fixed: presumably they must match the node/edge
    # feature layout produced by GraphDataModule — TODO confirm before changing.
    model = GatingGNO(
        x_ch_node=1,
        f_ch_node=1,
        hidden=hidden,
        layers=layers,
        edge_ch=3,
        out_ch=1,
    )
    solver = GraphSolver(model)
    trainer.fit(solver, datamodule=data_module)
    # NOTE(review): a 10% test split is configured above, but trainer.test()
    # is never called here, so that split is not evaluated by this script.
    print("Done!")
if __name__ == "__main__":
    # Launch training only when this file is executed as a script,
    # not when it is imported as a module.
    main()