24 lines
747 B
Python
24 lines
747 B
Python
from lightning import Trainer
|
|
from ThermalSolver.module import GraphSolver
|
|
from ThermalSolver.data_module import GraphDataModule
|
|
from ThermalSolver.model.local_gno import GatingGNO
|
|
from ThermalSolver.model.basic_gno import GNO
|
|
|
|
|
|
def main(
    *,
    max_epochs: int = 100,
    accelerator: str = "cuda",
    devices: int = 1,
    hf_repo: str = "SISSAmathLab/thermal-conduction",
    split_name: str = "easy",
    batch_size: int = 8,
    hidden: int = 16,
    layers: int = 8,
) -> None:
    """Train a ``GatingGNO`` graph model on the thermal-conduction dataset.

    All keyword arguments default to the values this script originally
    hard-coded, so ``main()`` with no arguments behaves exactly as before.
    They exist so the same entry point can be reused (e.g. on a CPU-only
    machine via ``accelerator="cpu"``) without editing the source.

    Args:
        max_epochs: Maximum number of training epochs for the Lightning Trainer.
        accelerator: Lightning accelerator name passed to ``Trainer``.
        devices: Number of devices passed to ``Trainer``.
        hf_repo: Hugging Face dataset repository loaded by ``GraphDataModule``.
        split_name: Dataset configuration/split name inside ``hf_repo``.
        batch_size: Mini-batch size used by the data module.
        hidden: Hidden width passed to ``GatingGNO``.
        layers: Number of layers passed to ``GatingGNO``.

    Note:
        Only ``trainer.fit`` is called, so the test split configured below
        (``test_size=0.1``) is not evaluated by this function.
    """
    trainer = Trainer(
        max_epochs=max_epochs,
        accelerator=accelerator,
        devices=devices,
    )

    # The 0.8 / 0.1 / 0.1 proportions partition the data into train / val / test.
    data_module = GraphDataModule(
        hf_repo=hf_repo,
        split_name=split_name,
        train_size=0.8,
        val_size=0.1,
        test_size=0.1,
        batch_size=batch_size,
    )

    # Channel counts are fixed to match the dataset's feature layout;
    # NOTE(review): presumably 1 node input channel, 1 node forcing channel,
    # 3 edge channels and 1 output channel — confirm against GraphDataModule.
    model = GatingGNO(
        x_ch_node=1,
        f_ch_node=1,
        hidden=hidden,
        layers=layers,
        edge_ch=3,
        out_ch=1,
    )

    # GraphSolver wraps the raw model; it is the object handed to trainer.fit.
    solver = GraphSolver(model)

    trainer.fit(solver, datamodule=data_module)

    print("Done!")
|
|
|
|
if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()