transfer files

This commit is contained in:
Filippo Olivo
2025-11-25 19:19:31 +01:00
parent edba700d2a
commit 88bc5c05e4
13 changed files with 926 additions and 163 deletions

View File

@@ -5,7 +5,7 @@ import importlib
from matplotlib import pyplot as plt
from matplotlib.tri import Triangulation
from .model.finite_difference import FiniteDifferenceStep
import os
def import_class(class_path: str):
module_path, class_name = class_path.rsplit(".", 1) # split last dot
@@ -14,13 +14,15 @@ def import_class(class_path: str):
return cls
def _plot_mesh(pos, y, y_pred, batch, i):
def _plot_mesh(pos, y, y_pred, batch, i, batch_idx):
idx = batch == 0
y = y[idx].detach().cpu()
y_pred = y_pred[idx].detach().cpu()
pos = pos[idx].detach().cpu()
folder = f"{batch_idx:02d}_images"
if os.path.exists(folder) is False:
os.makedirs(folder)
pos = pos.detach().cpu()
tria = Triangulation(pos[:, 0], pos[:, 1])
plt.figure(figsize=(18, 5))
@@ -37,10 +39,23 @@ def _plot_mesh(pos, y, y_pred, batch, i):
plt.colorbar()
plt.title("Error")
plt.suptitle("GNO", fontsize=16)
name = f"images/graph_iter_{i:04d}.png"
name = f"{folder}/graph_iter_{i:04d}.png"
plt.savefig(name, dpi=72)
plt.close()
def _plot_losses(losses, batch_idx):
folder = f"{batch_idx:02d}_images"
plt.figure()
plt.plot(losses)
plt.yscale("log")
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.title("Test Loss over Iterations")
plt.grid(True)
file_name = f"{folder}/test_loss.png"
plt.savefig(file_name, dpi=300)
plt.close()
class GraphSolver(LightningModule):
def __init__(
@@ -231,7 +246,6 @@ class GraphSolver(LightningModule):
x, y, edge_index, edge_attr = self._preprocess_batch(batch)
deg = self._compute_deg(edge_index, edge_attr, x.size(0))
for i in range(self.current_iters):
out = self._compute_model_steps(
x,
@@ -257,36 +271,8 @@ class GraphSolver(LightningModule):
batch_size=int(batch.num_graphs),
)
def test_step(self, batch: Batch, _):
x, y, edge_index, edge_attr = self._preprocess_batch(batch)
deg = self._compute_deg(edge_index, edge_attr, x.size(0))
for i in range(self.max_iters):
out = self._compute_model_steps(
x,
edge_index,
edge_attr.unsqueeze(-1),
deg,
batch.boundary_mask,
batch.boundary_values,
)
converged = self._check_convergence(out, x)
# _plot_mesh(batch.pos, y, out, batch.batch, i)
if converged:
break
x = out
loss = self.loss(out, y)
self._log_loss(loss, batch, "test")
self.log(
"test/iterations",
i + 1,
on_step=False,
on_epoch=True,
prog_bar=True,
batch_size=int(batch.num_graphs),
)
def test_step(self, batch: Batch, batch_idx):
pass
def configure_optimizers(self):
optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)