fix model

FilippoOlivo
2025-12-01 14:55:13 +01:00
parent c36c59d08d
commit 54bebf7154
5 changed files with 167 additions and 88 deletions


@@ -7,6 +7,7 @@ from matplotlib.tri import Triangulation
from .model.finite_difference import FiniteDifferenceStep
import os


def import_class(class_path: str):
    module_path, class_name = class_path.rsplit(".", 1) # split last dot
    module = importlib.import_module(module_path) # import the module
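For reference, a dotted class path is resolved in two steps: split at the last dot, import the module part, then look the class up on it. A minimal, self-contained illustration (the helper name `resolve` is a hypothetical stand-in for `import_class`, not part of the commit):

import importlib

def resolve(class_path: str):
    # Split "pkg.module.Class" into "pkg.module" and "Class", then import.
    module_path, class_name = class_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), class_name)

assert resolve("collections.OrderedDict").__name__ == "OrderedDict"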
@@ -14,7 +15,7 @@ def import_class(class_path: str):
    return cls


-def _plot_mesh(pos_, y_, y_pred_, y_true_ ,batch, i, batch_idx):
+def _plot_mesh(pos_, y_, y_pred_, y_true_, batch, i, batch_idx):
    for j in [0, 10, 20, 30]:
        idx = (batch == j).nonzero(as_tuple=True)[0]
        y = y_[idx].detach().cpu()
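The node selection above relies on the PyTorch Geometric batch vector, which stores one graph index per node, so the `nonzero` call returns the node indices belonging to graph `j`. A small illustration:

import torch

batch = torch.tensor([0, 0, 0, 1, 1, 2])      # three graphs with 3, 2 and 1 nodes
idx = (batch == 1).nonzero(as_tuple=True)[0]  # node indices of graph 1
assert idx.tolist() == [3, 4]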
@@ -49,6 +50,7 @@ def _plot_mesh(pos_, y_, y_pred_, y_true_ ,batch, i, batch_idx):
        plt.savefig(name, dpi=72)
        plt.close()


def _plot_losses(losses, batch_idx):
    folder = f"{batch_idx:02d}_images"
    plt.figure()
@@ -74,8 +76,8 @@ class GraphSolver(LightningModule):
        super().__init__()
        self.model = import_class(model_class_path)(**model_init_args)
        # for param in self.model.parameters():
        # print(f"Param: {param.shape}, Grad: {param.grad}")
        # print(f"Param: {param[0]}")
# print(f"Param: {param.shape}, Grad: {param.grad}")
# print(f"Param: {param[0]}")
        self.loss = loss if loss is not None else torch.nn.MSELoss()
        self.unrolling_steps = unrolling_steps
@@ -101,29 +103,36 @@ class GraphSolver(LightningModule):
        return (0.5 * (c[edge_index[0]] + c[edge_index[1]])).squeeze()

    def _compute_model_steps(
-        self, x, edge_index, edge_attr, boundary_mask, boundary_values
-    ):
-        out = self.model(x, edge_index, edge_attr)
+        self,
+        x,
+        edge_index,
+        edge_attr,
+        boundary_mask,
+        boundary_values,
+        conductivity,
+    ):
+        out = self.model(x, edge_index, edge_attr, conductivity)
        out[boundary_mask] = boundary_values.unsqueeze(-1)
        # print(torch.min(out), torch.max(out))
        return out

    def _preprocess_batch(self, batch: Batch):
-        x, y, c, edge_index, edge_attr = (
+        x, y, c, edge_index, edge_attr, nodal_area = (
            batch.x,
            batch.y,
            batch.c,
            batch.edge_index,
            batch.edge_attr,
+            batch.nodal_area,
        )
        edge_attr = 1 / edge_attr
-        c_ij = self._compute_c_ij(c, edge_index)
-        edge_attr = edge_attr * c_ij
-        # edge_attr = edge_attr / torch.max(edge_attr)
-        return x, y, edge_index, edge_attr
+        conductivity = self._compute_c_ij(c, edge_index)
+        edge_attr = edge_attr * conductivity
+        return x, y, edge_index, edge_attr, conductivity

    def training_step(self, batch: Batch):
-        x, y, edge_index, edge_attr = self._preprocess_batch(batch)
+        x, y, edge_index, edge_attr, conductivity = self._preprocess_batch(
+            batch
+        )
        # deg = self._compute_deg(edge_index, edge_attr, x.size(0))
        losses = []
        # print(x.shape, y.shape)
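With this change `_preprocess_batch` also returns the per-edge conductivity from `_compute_c_ij` (the arithmetic mean of the nodal conductivity at the two endpoints of each edge), and `_compute_model_steps` forwards it to the wrapped model as a fourth positional argument. The model behind `model_class_path` is therefore assumed to accept that extra argument; a minimal sketch of a compatible forward, using hypothetical names and not the project's actual model, could look like:

import torch
from torch import nn

class ConductivityAwareStep(nn.Module):
    # Hypothetical stand-in for the model referenced by model_class_path.
    def forward(self, x, edge_index, edge_attr, conductivity):
        # edge_attr is accepted only for signature compatibility in this sketch.
        src, dst = edge_index                         # [2, num_edges]
        # Diffusive flux per edge, weighted by the edge conductivity
        # (c_i + c_j) / 2 that _preprocess_batch provides.
        flux = conductivity.unsqueeze(-1) * (x[src] - x[dst])
        out = x.clone()
        out.index_add_(0, dst, flux)                  # accumulate flux at target nodes
        return out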
@@ -160,12 +169,13 @@ class GraphSolver(LightningModule):
                # deg,
                batch.boundary_mask,
                batch.boundary_values,
+                conductivity,
            )
            x = out
            # print(out.shape, y[:, i, :].shape)
            losses.append(self.loss(out.flatten(), y[:, i, :].flatten()))
            # print(self.model.scale_edge_attr.item())
        loss = torch.stack(losses).mean()
        # for param in self.model.parameters():
        # print(f"Param: {param.shape}, Grad: {param.grad}")
@@ -173,26 +183,40 @@ class GraphSolver(LightningModule):
        self._log_loss(loss, batch, "train")
        return loss

    def validation_step(self, batch: Batch, batch_idx):
-        x, y, edge_index, edge_attr = self._preprocess_batch(batch)
+        x, y, edge_index, edge_attr, conductivity = self._preprocess_batch(
+            batch
+        )
        # deg = self._compute_deg(edge_index, edge_attr, x.size(0))
        losses = []
        pos = batch.pos
        for i in range(self.unrolling_steps):
            out = self._compute_model_steps(
                # torch.cat([x,pos], dim=-1),
                x,
                edge_index,
                edge_attr,
                # deg,
                batch.boundary_mask,
                batch.boundary_values,
+                conductivity,
            )
-            if (batch_idx == 0 and self.current_epoch % 10 == 0 and self.current_epoch > 20):
-                _plot_mesh(batch.pos, x, out, y[:, i, :], batch.batch, i, self.current_epoch)
+            if (
+                batch_idx == 0
+                and self.current_epoch % 10 == 0
+                and self.current_epoch > 0
+            ):
+                _plot_mesh(
+                    batch.pos,
+                    x,
+                    out,
+                    y[:, i, :],
+                    batch.batch,
+                    i,
+                    self.current_epoch,
+                )
            x = out
-            losses.append(self.loss(out , y[:, i, :]))
+            losses.append(self.loss(out, y[:, i, :]))
        loss = torch.stack(losses).mean()
        self._log_loss(loss, batch, "val")
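With the relaxed guard (`self.current_epoch > 0` instead of `> 20`), mesh plots are now produced for the first validation batch at every tenth epoch starting from epoch 10:

plot_epochs = [e for e in range(50) if e % 10 == 0 and e > 0]
assert plot_epochs == [10, 20, 30, 40]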
@@ -202,5 +226,5 @@ class GraphSolver(LightningModule):
        pass

    def configure_optimizers(self):
-        optimizer = torch.optim.AdamW(self.parameters(), lr=5e-3)
+        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        return optimizer
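A hedged usage sketch of the solver after this commit, assuming the constructor keywords match the attributes set in `__init__` and that PyTorch Geometric data loaders (`train_loader`, `val_loader`, both hypothetical here) already exist; the diff does not show the Lightning import, so the package name below is an assumption:

# Illustrative only: GraphSolver's exact constructor signature is not shown in
# this diff, so the keyword names below are inferred from __init__.
import pytorch_lightning as pl  # or: import lightning.pytorch as pl

solver = GraphSolver(
    model_class_path="model.finite_difference.FiniteDifferenceStep",  # hypothetical dotted path
    model_init_args={},   # forwarded to the model's __init__
    loss=None,            # None falls back to torch.nn.MSELoss()
    unrolling_steps=4,    # hypothetical unrolling horizon
)
trainer = pl.Trainer(max_epochs=100)
trainer.fit(solver, train_dataloaders=train_loader, val_dataloaders=val_loader)  # loaders assumed to exist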