new plotting strategy
This commit is contained in:
@@ -15,7 +15,7 @@ def import_class(class_path: str):
|
||||
return cls
|
||||
|
||||
|
||||
def _plot_mesh(pos_, y_, y_pred_, y_true_, batch, i, batch_idx):
|
||||
def _plot_mesh(pos_, y_, y_pred_, y_true_, batch, cells, i, batch_idx):
|
||||
# print(pos_.shape, y_.shape, y_pred_.shape, y_true_.shape)
|
||||
for j in [0]:
|
||||
idx = (batch == j).nonzero(as_tuple=True)[0]
|
||||
@@ -28,7 +28,8 @@ def _plot_mesh(pos_, y_, y_pred_, y_true_, batch, i, batch_idx):
|
||||
folder = f"{batch_idx:02d}_images"
|
||||
if os.path.exists(folder) is False:
|
||||
os.makedirs(folder)
|
||||
tria = Triangulation(pos[:, 0], pos[:, 1])
|
||||
triangles = torch.vstack([cells[:, [0, 1, 2]], cells[:, [0, 2, 3]]])
|
||||
tria = Triangulation(pos[:, 0], pos[:, 1], triangles=triangles)
|
||||
plt.figure(figsize=(24, 6))
|
||||
# plt.subplot(1, 4, 1)
|
||||
# plt.tricontourf(tria, y.squeeze().numpy(), levels=100)
|
||||
@@ -38,51 +39,36 @@ def _plot_mesh(pos_, y_, y_pred_, y_true_, batch, i, batch_idx):
|
||||
# plt.savefig("test_scatter_step_before.png", dpi=72)
|
||||
# x = z
|
||||
plt.subplot(1, 4, 1)
|
||||
# plt.tricontourf(tria, y_pred.squeeze().numpy(), levels=100)
|
||||
plt.scatter(
|
||||
pos[:, 0],
|
||||
pos[:, 1],
|
||||
c=y_pred.squeeze().numpy(),
|
||||
s=20,
|
||||
cmap="viridis",
|
||||
)
|
||||
plt.tricontourf(tria, y_pred.squeeze().numpy(), levels=100)
|
||||
# plt.scatter(pos[:, 0], pos[:, 1], c=y_pred.squeeze().numpy(), s=20, cmap="viridis",)
|
||||
plt.colorbar()
|
||||
plt.title(f"Prediction at timestep {i:03d}")
|
||||
plt.subplot(1, 4, 2)
|
||||
# plt.tricontourf(tria, y_true.squeeze().numpy(), levels=100)
|
||||
plt.scatter(
|
||||
pos[:, 0],
|
||||
pos[:, 1],
|
||||
c=y_true.squeeze().numpy(),
|
||||
s=20,
|
||||
cmap="viridis",
|
||||
)
|
||||
plt.tricontourf(tria, y_true.squeeze().numpy(), levels=100)
|
||||
# plt.scatter(pos[:, 0], pos[:, 1], c=y_true.squeeze().numpy(), s=20, cmap="viridis")
|
||||
plt.colorbar()
|
||||
plt.title("Ground Truth Steady State")
|
||||
plt.subplot(1, 4, 3)
|
||||
per_element_relative_error = torch.abs(y_pred - y_true) / (y_true + 1e-6)
|
||||
# plt.tricontourf(tria, per_element_relative_error.squeeze(), levels=100)
|
||||
plt.scatter(
|
||||
pos[:, 0],
|
||||
pos[:, 1],
|
||||
c=per_element_relative_error.squeeze().numpy(),
|
||||
s=20,
|
||||
cmap="viridis",
|
||||
per_element_relative_error = torch.abs(y_pred - y_true) / (
|
||||
y_true + 1e-6
|
||||
)
|
||||
per_element_relative_error = torch.clamp(
|
||||
per_element_relative_error, max=1.0, min=0.0
|
||||
)
|
||||
plt.tricontourf(
|
||||
tria,
|
||||
per_element_relative_error.squeeze(),
|
||||
levels=100,
|
||||
vmin=0,
|
||||
vmax=1.0,
|
||||
)
|
||||
# plt.scatter(pos[:, 0], pos[:, 1], c=per_element_relative_error.squeeze().numpy(), s=20, cmap="viridis", vmin=0, vmax=1.0)
|
||||
plt.colorbar()
|
||||
plt.title("Relative Error")
|
||||
plt.subplot(1, 4, 4)
|
||||
absolute_error = torch.abs(y_pred - y_true)
|
||||
# plt.tricontourf(tria, absolute_error.squeeze(), levels=100)
|
||||
plt.scatter(
|
||||
pos[:, 0],
|
||||
pos[:, 1],
|
||||
c=absolute_error.squeeze().numpy(),
|
||||
s=20,
|
||||
cmap="viridis",
|
||||
)
|
||||
plt.tricontourf(tria, absolute_error.squeeze(), levels=100)
|
||||
# plt.scatter(pos[:, 0], pos[:, 1], c=absolute_error.squeeze().numpy(), s=20, cmap="viridis")
|
||||
plt.colorbar()
|
||||
plt.title("Absolute Error")
|
||||
plt.suptitle("GNO", fontsize=16)
|
||||
@@ -292,15 +278,9 @@ class GraphSolver(LightningModule):
|
||||
sequence_length = y.size(1)
|
||||
y = y[:, -1, :].unsqueeze(1)
|
||||
_plot_mesh(
|
||||
batch.pos,
|
||||
x,
|
||||
x,
|
||||
y[:, -1, :],
|
||||
batch.batch,
|
||||
0,
|
||||
batch_idx
|
||||
batch.pos, x, x, y[:, -1, :], batch.batch, batch.cells, 0, batch_idx
|
||||
)
|
||||
for i in range(100):
|
||||
for i in range(200):
|
||||
out = self._compute_model_steps(
|
||||
# torch.cat([x,pos], dim=-1),
|
||||
x,
|
||||
@@ -322,12 +302,15 @@ class GraphSolver(LightningModule):
|
||||
out,
|
||||
y[:, -1, :],
|
||||
batch.batch,
|
||||
batch.cells,
|
||||
i + 1,
|
||||
batch_idx
|
||||
batch_idx,
|
||||
)
|
||||
x = out
|
||||
loss = self.loss(out, y[:, -1, :])
|
||||
relative_error = torch.abs(out - y[:, -1, :]) / (torch.abs(y[:, -1, :]) + 1e-6)
|
||||
relative_error = torch.abs(out - y[:, -1, :]) / (
|
||||
torch.abs(y[:, -1, :]) + 1e-6
|
||||
)
|
||||
mean_relative_error = relative_error.mean()
|
||||
all_losses.append(mean_relative_error.item())
|
||||
losses.append(loss)
|
||||
@@ -345,7 +328,12 @@ class GraphSolver(LightningModule):
|
||||
|
||||
def on_test_end(self):
|
||||
if len(self.test_losses) > 0:
|
||||
_plot_losses(self.test_relative_errors, self.test_losses, self.test_relative_updates, batch_idx=0)
|
||||
_plot_losses(
|
||||
self.test_relative_errors,
|
||||
self.test_losses,
|
||||
self.test_relative_updates,
|
||||
batch_idx=0,
|
||||
)
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
|
||||
|
||||
@@ -82,7 +82,9 @@ class GraphDataModule(LightningDataModule):
|
||||
conductivity = torch.tensor(
|
||||
snapshot["conductivity"], dtype=torch.float32
|
||||
)
|
||||
temperature = torch.tensor(snapshot["temperature"], dtype=torch.float32)[:50]
|
||||
temperature = torch.tensor(
|
||||
snapshot["temperature"], dtype=torch.float32
|
||||
)[:50]
|
||||
|
||||
pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2]
|
||||
|
||||
|
||||
@@ -117,7 +117,9 @@ class GraphDataModule(LightningDataModule):
|
||||
if not test:
|
||||
for t in range(1, temperatures.size(0)):
|
||||
diff = temperatures[t, :] - temperatures[t - 1, :]
|
||||
norm_diff = torch.norm(diff, p=2) / torch.norm(temperatures[t - 1], p=2)
|
||||
norm_diff = torch.norm(diff, p=2) / torch.norm(
|
||||
temperatures[t - 1], p=2
|
||||
)
|
||||
if norm_diff < self.min_normalized_diff:
|
||||
temperatures = temperatures[: t + 1, :]
|
||||
break
|
||||
@@ -148,6 +150,9 @@ class GraphDataModule(LightningDataModule):
|
||||
data = []
|
||||
|
||||
if test:
|
||||
cells = geometry.get("cells", None)
|
||||
if cells is not None:
|
||||
cells = torch.tensor(cells, dtype=torch.int64)
|
||||
data.append(
|
||||
MeshData(
|
||||
x=temperatures[0, :].unsqueeze(-1),
|
||||
@@ -158,6 +163,7 @@ class GraphDataModule(LightningDataModule):
|
||||
edge_attr=edge_attr,
|
||||
boundary_mask=boundary_mask,
|
||||
boundary_values=boundary_values,
|
||||
cells=cells,
|
||||
)
|
||||
)
|
||||
return data
|
||||
|
||||
Reference in New Issue
Block a user