new plotting strategy
This commit is contained in:
@@ -117,7 +117,9 @@ class GraphDataModule(LightningDataModule):
|
||||
if not test:
|
||||
for t in range(1, temperatures.size(0)):
|
||||
diff = temperatures[t, :] - temperatures[t - 1, :]
|
||||
norm_diff = torch.norm(diff, p=2) / torch.norm(temperatures[t - 1], p=2)
|
||||
norm_diff = torch.norm(diff, p=2) / torch.norm(
|
||||
temperatures[t - 1], p=2
|
||||
)
|
||||
if norm_diff < self.min_normalized_diff:
|
||||
temperatures = temperatures[: t + 1, :]
|
||||
break
|
||||
@@ -148,6 +150,9 @@ class GraphDataModule(LightningDataModule):
|
||||
data = []
|
||||
|
||||
if test:
|
||||
cells = geometry.get("cells", None)
|
||||
if cells is not None:
|
||||
cells = torch.tensor(cells, dtype=torch.int64)
|
||||
data.append(
|
||||
MeshData(
|
||||
x=temperatures[0, :].unsqueeze(-1),
|
||||
@@ -158,6 +163,7 @@ class GraphDataModule(LightningDataModule):
|
||||
edge_attr=edge_attr,
|
||||
boundary_mask=boundary_mask,
|
||||
boundary_values=boundary_values,
|
||||
cells=cells,
|
||||
)
|
||||
)
|
||||
return data
|
||||
|
||||
Reference in New Issue
Block a user