new plotting strategy

This commit is contained in:
2025-12-19 15:50:47 +01:00
parent 68a7def5e6
commit 92104a6b06
3 changed files with 45 additions and 49 deletions

View File

@@ -117,7 +117,9 @@ class GraphDataModule(LightningDataModule):
if not test:
for t in range(1, temperatures.size(0)):
diff = temperatures[t, :] - temperatures[t - 1, :]
norm_diff = torch.norm(diff, p=2) / torch.norm(temperatures[t - 1], p=2)
norm_diff = torch.norm(diff, p=2) / torch.norm(
temperatures[t - 1], p=2
)
if norm_diff < self.min_normalized_diff:
temperatures = temperatures[: t + 1, :]
break
@@ -148,6 +150,9 @@ class GraphDataModule(LightningDataModule):
data = []
if test:
cells = geometry.get("cells", None)
if cells is not None:
cells = torch.tensor(cells, dtype=torch.int64)
data.append(
MeshData(
x=temperatures[0, :].unsqueeze(-1),
@@ -158,6 +163,7 @@ class GraphDataModule(LightningDataModule):
edge_attr=edge_attr,
boundary_mask=boundary_mask,
boundary_values=boundary_values,
cells=cells,
)
)
return data