From a9d56a3ed9edfdcbc2a97ebdb01709f05d4fafdd Mon Sep 17 00:00:00 2001 From: Filippo Olivo Date: Mon, 15 Dec 2025 09:08:21 +0100 Subject: [PATCH] fix model and datamodule --- ThermalSolver/autoregressive_module.py | 82 +++++++++++----------- ThermalSolver/graph_datamodule_unsteady.py | 37 +++++++--- 2 files changed, 69 insertions(+), 50 deletions(-) diff --git a/ThermalSolver/autoregressive_module.py b/ThermalSolver/autoregressive_module.py index 6f85528..f5254dd 100644 --- a/ThermalSolver/autoregressive_module.py +++ b/ThermalSolver/autoregressive_module.py @@ -17,7 +17,7 @@ def import_class(class_path: str): def _plot_mesh(pos_, y_, y_pred_, y_true_, batch, i, batch_idx): # print(pos_.shape, y_.shape, y_pred_.shape, y_true_.shape) - for j in [0]: + for j in [0, 5, 10, 20]: idx = (batch == j).nonzero(as_tuple=True)[0] y = y_[idx].detach().cpu() y_pred = y_pred_[idx].detach().cpu() @@ -38,39 +38,37 @@ def _plot_mesh(pos_, y_, y_pred_, y_true_, batch, i, batch_idx): # plt.savefig("test_scatter_step_before.png", dpi=72) # x = z plt.subplot(1, 3, 1) - # plt.tricontourf(tria, y_pred.squeeze().numpy(), levels=100) - plt.scatter( - pos[:, 0], - pos[:, 1], - c=y_pred.squeeze().numpy(), - s=20, - cmap="viridis", - ) + plt.tricontourf(tria, y_pred.squeeze().numpy(), levels=100) + # plt.scatter( + # pos[:, 0], + # pos[:, 1], + # c=y_pred.squeeze().numpy(), + # s=20, + # cmap="viridis", + # ) plt.colorbar() plt.title("Step t Predicted") plt.subplot(1, 3, 2) - # plt.tricontourf(tria, y_true.squeeze().numpy(), levels=100) - plt.scatter( - pos[:, 0], - pos[:, 1], - c=y_true.squeeze().numpy(), - s=20, - cmap="viridis", - ) + plt.tricontourf(tria, y_true.squeeze().numpy(), levels=100) + # plt.scatter( + # pos[:, 0], + # pos[:, 1], + # c=y_true.squeeze().numpy(), + # s=20, + # cmap="viridis", + # ) plt.colorbar() plt.title("t True") plt.subplot(1, 3, 3) - per_element_relative_error = torch.abs(y_pred - y_true) / torch.clamp( - torch.abs(y_true), min=1e-6 - ) - # plt.tricontourf(tria, per_element_relative_error.squeeze(), levels=100) - plt.scatter( - pos[:, 0], - pos[:, 1], - c=per_element_relative_error.squeeze().numpy(), - s=20, - cmap="viridis", - ) + per_element_relative_error = torch.abs(y_pred - y_true) + plt.tricontourf(tria, per_element_relative_error.squeeze(), levels=100) + # plt.scatter( + # pos[:, 0], + # pos[:, 1], + # c=per_element_relative_error.squeeze().numpy(), + # s=20, + # cmap="viridis", + # ) plt.colorbar() plt.title("Relative Error") plt.suptitle("GNO", fontsize=16) @@ -216,20 +214,20 @@ class GraphSolver(LightningModule): batch.boundary_values, conductivity, ) - if ( - batch_idx == 0 - and self.current_epoch % 10 == 0 - and self.current_epoch > 0 - ): - _plot_mesh( - batch.pos, - x, - out, - y[:, i, :], - batch.batch, - i, - self.current_epoch, - ) + # if ( + # batch_idx == 0 + # and self.current_epoch % 10 == 0 + # and self.current_epoch > 0 + # ): + # _plot_mesh( + # batch.pos, + # x, + # out, + # y[:, i, :], + # batch.batch, + # i, + # self.current_epoch, + # ) x = out losses.append(self.loss(out, y[:, i, :])) diff --git a/ThermalSolver/graph_datamodule_unsteady.py b/ThermalSolver/graph_datamodule_unsteady.py index 4026a1d..9ec9df2 100644 --- a/ThermalSolver/graph_datamodule_unsteady.py +++ b/ThermalSolver/graph_datamodule_unsteady.py @@ -1,18 +1,19 @@ import torch from tqdm import tqdm from lightning import LightningDataModule -from datasets import load_dataset +from datasets import load_dataset, concatenate_datasets from torch_geometric.data import Data from torch_geometric.loader import DataLoader from torch_geometric.utils import to_undirected from .mesh_data import MeshData +from typing import List, Union class GraphDataModule(LightningDataModule): def __init__( self, hf_repo: str, - split_name: str, + split_name: Union[str, List[str]], n_elements: int = None, train_size: float = 0.2, val_size: float = 0.1, @@ -44,8 +45,30 @@ class GraphDataModule(LightningDataModule): self.radius = radius def prepare_data(self): - dataset = load_dataset(self.hf_repo, name="snapshots")[self.split_name] - geometry = load_dataset(self.hf_repo, name="geometry")[self.split_name] + if isinstance(self.split_name, list): + dataset_list = [] + geometry_list = [] + for split in self.split_name: + dataset_list.append( + load_dataset(self.hf_repo, name="snapshots")[split] + ) + geometry_list.append( + load_dataset(self.hf_repo, name="geometry")[split] + ) + + dataset = concatenate_datasets(dataset_list) + geometry = concatenate_datasets(geometry_list) + idx = torch.randperm(len(dataset)) + dataset = dataset.select(idx.tolist()) + geometry = geometry.select(idx.tolist()) + else: + dataset = load_dataset(self.hf_repo, name="snapshots")[ + self.split_name + ] + geometry = load_dataset(self.hf_repo, name="geometry")[ + self.split_name + ] + if self.n_elements is not None: dataset = dataset.select(range(self.n_elements)) geometry = geometry.select(range(self.n_elements)) @@ -86,7 +109,7 @@ class GraphDataModule(LightningDataModule): dim=0, ) ) - print(temperatures.shape) + # print(temperatures.shape) pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2] @@ -103,9 +126,7 @@ class GraphDataModule(LightningDataModule): boundary_mask = torch.tensor( geometry["constraints_mask"], dtype=torch.int64 ) - boundary_values = torch.tensor( - geometry["constraints_values"], dtype=torch.float32 - ) + boundary_values = temperatures[0, boundary_mask] edge_attr = torch.norm(pos[edge_index[0]] - pos[edge_index[1]], dim=1) if self.remove_boundary_edges: