new plotting strategy
This commit is contained in:
@@ -15,7 +15,7 @@ def import_class(class_path: str):
|
|||||||
return cls
|
return cls
|
||||||
|
|
||||||
|
|
||||||
def _plot_mesh(pos_, y_, y_pred_, y_true_, batch, i, batch_idx):
|
def _plot_mesh(pos_, y_, y_pred_, y_true_, batch, cells, i, batch_idx):
|
||||||
# print(pos_.shape, y_.shape, y_pred_.shape, y_true_.shape)
|
# print(pos_.shape, y_.shape, y_pred_.shape, y_true_.shape)
|
||||||
for j in [0]:
|
for j in [0]:
|
||||||
idx = (batch == j).nonzero(as_tuple=True)[0]
|
idx = (batch == j).nonzero(as_tuple=True)[0]
|
||||||
@@ -28,7 +28,8 @@ def _plot_mesh(pos_, y_, y_pred_, y_true_, batch, i, batch_idx):
|
|||||||
folder = f"{batch_idx:02d}_images"
|
folder = f"{batch_idx:02d}_images"
|
||||||
if os.path.exists(folder) is False:
|
if os.path.exists(folder) is False:
|
||||||
os.makedirs(folder)
|
os.makedirs(folder)
|
||||||
tria = Triangulation(pos[:, 0], pos[:, 1])
|
triangles = torch.vstack([cells[:, [0, 1, 2]], cells[:, [0, 2, 3]]])
|
||||||
|
tria = Triangulation(pos[:, 0], pos[:, 1], triangles=triangles)
|
||||||
plt.figure(figsize=(24, 6))
|
plt.figure(figsize=(24, 6))
|
||||||
# plt.subplot(1, 4, 1)
|
# plt.subplot(1, 4, 1)
|
||||||
# plt.tricontourf(tria, y.squeeze().numpy(), levels=100)
|
# plt.tricontourf(tria, y.squeeze().numpy(), levels=100)
|
||||||
@@ -38,51 +39,36 @@ def _plot_mesh(pos_, y_, y_pred_, y_true_, batch, i, batch_idx):
|
|||||||
# plt.savefig("test_scatter_step_before.png", dpi=72)
|
# plt.savefig("test_scatter_step_before.png", dpi=72)
|
||||||
# x = z
|
# x = z
|
||||||
plt.subplot(1, 4, 1)
|
plt.subplot(1, 4, 1)
|
||||||
# plt.tricontourf(tria, y_pred.squeeze().numpy(), levels=100)
|
plt.tricontourf(tria, y_pred.squeeze().numpy(), levels=100)
|
||||||
plt.scatter(
|
# plt.scatter(pos[:, 0], pos[:, 1], c=y_pred.squeeze().numpy(), s=20, cmap="viridis",)
|
||||||
pos[:, 0],
|
|
||||||
pos[:, 1],
|
|
||||||
c=y_pred.squeeze().numpy(),
|
|
||||||
s=20,
|
|
||||||
cmap="viridis",
|
|
||||||
)
|
|
||||||
plt.colorbar()
|
plt.colorbar()
|
||||||
plt.title(f"Prediction at timestep {i:03d}")
|
plt.title(f"Prediction at timestep {i:03d}")
|
||||||
plt.subplot(1, 4, 2)
|
plt.subplot(1, 4, 2)
|
||||||
# plt.tricontourf(tria, y_true.squeeze().numpy(), levels=100)
|
plt.tricontourf(tria, y_true.squeeze().numpy(), levels=100)
|
||||||
plt.scatter(
|
# plt.scatter(pos[:, 0], pos[:, 1], c=y_true.squeeze().numpy(), s=20, cmap="viridis")
|
||||||
pos[:, 0],
|
|
||||||
pos[:, 1],
|
|
||||||
c=y_true.squeeze().numpy(),
|
|
||||||
s=20,
|
|
||||||
cmap="viridis",
|
|
||||||
)
|
|
||||||
plt.colorbar()
|
plt.colorbar()
|
||||||
plt.title("Ground Truth Steady State")
|
plt.title("Ground Truth Steady State")
|
||||||
plt.subplot(1, 4, 3)
|
plt.subplot(1, 4, 3)
|
||||||
per_element_relative_error = torch.abs(y_pred - y_true) / (y_true + 1e-6)
|
per_element_relative_error = torch.abs(y_pred - y_true) / (
|
||||||
# plt.tricontourf(tria, per_element_relative_error.squeeze(), levels=100)
|
y_true + 1e-6
|
||||||
plt.scatter(
|
)
|
||||||
pos[:, 0],
|
per_element_relative_error = torch.clamp(
|
||||||
pos[:, 1],
|
per_element_relative_error, max=1.0, min=0.0
|
||||||
c=per_element_relative_error.squeeze().numpy(),
|
)
|
||||||
s=20,
|
plt.tricontourf(
|
||||||
cmap="viridis",
|
tria,
|
||||||
|
per_element_relative_error.squeeze(),
|
||||||
|
levels=100,
|
||||||
vmin=0,
|
vmin=0,
|
||||||
vmax=1.0,
|
vmax=1.0,
|
||||||
)
|
)
|
||||||
|
# plt.scatter(pos[:, 0], pos[:, 1], c=per_element_relative_error.squeeze().numpy(), s=20, cmap="viridis", vmin=0, vmax=1.0)
|
||||||
plt.colorbar()
|
plt.colorbar()
|
||||||
plt.title("Relative Error")
|
plt.title("Relative Error")
|
||||||
plt.subplot(1, 4, 4)
|
plt.subplot(1, 4, 4)
|
||||||
absolute_error = torch.abs(y_pred - y_true)
|
absolute_error = torch.abs(y_pred - y_true)
|
||||||
# plt.tricontourf(tria, absolute_error.squeeze(), levels=100)
|
plt.tricontourf(tria, absolute_error.squeeze(), levels=100)
|
||||||
plt.scatter(
|
# plt.scatter(pos[:, 0], pos[:, 1], c=absolute_error.squeeze().numpy(), s=20, cmap="viridis")
|
||||||
pos[:, 0],
|
|
||||||
pos[:, 1],
|
|
||||||
c=absolute_error.squeeze().numpy(),
|
|
||||||
s=20,
|
|
||||||
cmap="viridis",
|
|
||||||
)
|
|
||||||
plt.colorbar()
|
plt.colorbar()
|
||||||
plt.title("Absolute Error")
|
plt.title("Absolute Error")
|
||||||
plt.suptitle("GNO", fontsize=16)
|
plt.suptitle("GNO", fontsize=16)
|
||||||
@@ -292,15 +278,9 @@ class GraphSolver(LightningModule):
|
|||||||
sequence_length = y.size(1)
|
sequence_length = y.size(1)
|
||||||
y = y[:, -1, :].unsqueeze(1)
|
y = y[:, -1, :].unsqueeze(1)
|
||||||
_plot_mesh(
|
_plot_mesh(
|
||||||
batch.pos,
|
batch.pos, x, x, y[:, -1, :], batch.batch, batch.cells, 0, batch_idx
|
||||||
x,
|
)
|
||||||
x,
|
for i in range(200):
|
||||||
y[:, -1, :],
|
|
||||||
batch.batch,
|
|
||||||
0,
|
|
||||||
batch_idx
|
|
||||||
)
|
|
||||||
for i in range(100):
|
|
||||||
out = self._compute_model_steps(
|
out = self._compute_model_steps(
|
||||||
# torch.cat([x,pos], dim=-1),
|
# torch.cat([x,pos], dim=-1),
|
||||||
x,
|
x,
|
||||||
@@ -322,12 +302,15 @@ class GraphSolver(LightningModule):
|
|||||||
out,
|
out,
|
||||||
y[:, -1, :],
|
y[:, -1, :],
|
||||||
batch.batch,
|
batch.batch,
|
||||||
i+1,
|
batch.cells,
|
||||||
batch_idx
|
i + 1,
|
||||||
|
batch_idx,
|
||||||
)
|
)
|
||||||
x = out
|
x = out
|
||||||
loss = self.loss(out, y[:, -1, :])
|
loss = self.loss(out, y[:, -1, :])
|
||||||
relative_error = torch.abs(out - y[:, -1, :]) / (torch.abs(y[:, -1, :]) + 1e-6)
|
relative_error = torch.abs(out - y[:, -1, :]) / (
|
||||||
|
torch.abs(y[:, -1, :]) + 1e-6
|
||||||
|
)
|
||||||
mean_relative_error = relative_error.mean()
|
mean_relative_error = relative_error.mean()
|
||||||
all_losses.append(mean_relative_error.item())
|
all_losses.append(mean_relative_error.item())
|
||||||
losses.append(loss)
|
losses.append(loss)
|
||||||
@@ -345,7 +328,12 @@ class GraphSolver(LightningModule):
|
|||||||
|
|
||||||
def on_test_end(self):
|
def on_test_end(self):
|
||||||
if len(self.test_losses) > 0:
|
if len(self.test_losses) > 0:
|
||||||
_plot_losses(self.test_relative_errors, self.test_losses, self.test_relative_updates, batch_idx=0)
|
_plot_losses(
|
||||||
|
self.test_relative_errors,
|
||||||
|
self.test_losses,
|
||||||
|
self.test_relative_updates,
|
||||||
|
batch_idx=0,
|
||||||
|
)
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
|
optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
|
||||||
|
|||||||
@@ -82,7 +82,9 @@ class GraphDataModule(LightningDataModule):
|
|||||||
conductivity = torch.tensor(
|
conductivity = torch.tensor(
|
||||||
snapshot["conductivity"], dtype=torch.float32
|
snapshot["conductivity"], dtype=torch.float32
|
||||||
)
|
)
|
||||||
temperature = torch.tensor(snapshot["temperature"], dtype=torch.float32)[:50]
|
temperature = torch.tensor(
|
||||||
|
snapshot["temperature"], dtype=torch.float32
|
||||||
|
)[:50]
|
||||||
|
|
||||||
pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2]
|
pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2]
|
||||||
|
|
||||||
|
|||||||
@@ -117,7 +117,9 @@ class GraphDataModule(LightningDataModule):
|
|||||||
if not test:
|
if not test:
|
||||||
for t in range(1, temperatures.size(0)):
|
for t in range(1, temperatures.size(0)):
|
||||||
diff = temperatures[t, :] - temperatures[t - 1, :]
|
diff = temperatures[t, :] - temperatures[t - 1, :]
|
||||||
norm_diff = torch.norm(diff, p=2) / torch.norm(temperatures[t - 1], p=2)
|
norm_diff = torch.norm(diff, p=2) / torch.norm(
|
||||||
|
temperatures[t - 1], p=2
|
||||||
|
)
|
||||||
if norm_diff < self.min_normalized_diff:
|
if norm_diff < self.min_normalized_diff:
|
||||||
temperatures = temperatures[: t + 1, :]
|
temperatures = temperatures[: t + 1, :]
|
||||||
break
|
break
|
||||||
@@ -148,6 +150,9 @@ class GraphDataModule(LightningDataModule):
|
|||||||
data = []
|
data = []
|
||||||
|
|
||||||
if test:
|
if test:
|
||||||
|
cells = geometry.get("cells", None)
|
||||||
|
if cells is not None:
|
||||||
|
cells = torch.tensor(cells, dtype=torch.int64)
|
||||||
data.append(
|
data.append(
|
||||||
MeshData(
|
MeshData(
|
||||||
x=temperatures[0, :].unsqueeze(-1),
|
x=temperatures[0, :].unsqueeze(-1),
|
||||||
@@ -158,6 +163,7 @@ class GraphDataModule(LightningDataModule):
|
|||||||
edge_attr=edge_attr,
|
edge_attr=edge_attr,
|
||||||
boundary_mask=boundary_mask,
|
boundary_mask=boundary_mask,
|
||||||
boundary_values=boundary_values,
|
boundary_values=boundary_values,
|
||||||
|
cells=cells,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|||||||
Reference in New Issue
Block a user