add final loss and change model

This commit is contained in:
Filippo Olivo
2025-11-13 16:18:54 +01:00
parent dc59114f4a
commit ea9cf7c57c
2 changed files with 26 additions and 31 deletions

View File

@@ -65,7 +65,7 @@ class GraphSolver(LightningModule):
self.current_iters = start_iters self.current_iters = start_iters
self.accumulation_iters = accumulation_iters self.accumulation_iters = accumulation_iters
self.automatic_optimization = False self.automatic_optimization = False
self.threshold = 1e-2 self.threshold = 1e-5
def _compute_deg(self, edge_index, edge_attr, num_nodes): def _compute_deg(self, edge_index, edge_attr, num_nodes):
deg = torch.zeros(num_nodes, device=edge_index.device) deg = torch.zeros(num_nodes, device=edge_index.device)
@@ -102,14 +102,14 @@ class GraphSolver(LightningModule):
def _check_convergence(self, out, x): def _check_convergence(self, out, x):
residual_norm = torch.norm(out - x) residual_norm = torch.norm(out - x)
if residual_norm < self.threshold: if residual_norm < self.threshold * torch.norm(x):
return True return True
return False return False
def accumulate_gradients(self, losses, max_acc_iters): def accumulate_gradients(self, losses, max_acc_iters):
loss_ = torch.stack(losses, dim=0).mean() loss_ = torch.stack(losses, dim=0).mean()
loss = loss_ / self.accumulation_iters loss = 0.5 * loss_ / self.accumulation_iters
self.manual_backward(loss / max_acc_iters) self.manual_backward(loss / max_acc_iters, retain_graph=True)
return loss_.item() return loss_.item()
def _preprocess_batch(self, batch: Batch): def _preprocess_batch(self, batch: Batch):
@@ -177,7 +177,9 @@ class GraphSolver(LightningModule):
acc_loss = acc_loss + loss acc_loss = acc_loss + loss
x = out x = out
if i % self.accumulation_iters != 0:
loss = self.loss(out, y)
loss.backward()
optim.step() optim.step()
optim.zero_grad() optim.zero_grad()
@@ -190,6 +192,15 @@ class GraphSolver(LightningModule):
prog_bar=True, prog_bar=True,
batch_size=int(batch.num_graphs), batch_size=int(batch.num_graphs),
) )
if hasattr(self.model, "p"):
self.log(
"train/p",
self.model.p,
on_step=False,
on_epoch=True,
prog_bar=True,
batch_size=int(batch.num_graphs),
)
def on_train_epoch_end(self): def on_train_epoch_end(self):
if self.curriculum_learning: if self.curriculum_learning:
@@ -220,7 +231,7 @@ class GraphSolver(LightningModule):
if converged: if converged:
break break
x = out x = out
loss = self.loss(out, y) loss = 0.5 * self.loss(out, y)
self._log_loss(loss, batch, "val") self._log_loss(loss, batch, "val")
self.log( self.log(
"val/iterations", "val/iterations",
@@ -232,23 +243,6 @@ class GraphSolver(LightningModule):
) )
def test_step(self, batch: Batch, _): def test_step(self, batch: Batch, _):
# x, y, c, edge_index, edge_attr = self._preprocess_batch(batch)
# y_pred, _ = self.model(
# x,
# edge_index,
# edge_attr,
# c,
# batch.boundary_mask,
# batch.boundary_values,
# y=None,
# loss_fn=None,
# max_iters=1000,
# plot_results=True,
# batch=batch,
# )
# loss = self._compute_loss(y_pred, y)
# # _plot_mesh(batch.pos, y, y_pred, batch.batch)
# self._log_loss(loss, batch, "test")
x, y, edge_index, edge_attr = self._preprocess_batch(batch) x, y, edge_index, edge_attr = self._preprocess_batch(batch)
deg = self._compute_deg(edge_index, edge_attr, x.size(0)) deg = self._compute_deg(edge_index, edge_attr, x.size(0))
@@ -263,7 +257,7 @@ class GraphSolver(LightningModule):
batch.boundary_values, batch.boundary_values,
) )
converged = self._check_convergence(out, x) converged = self._check_convergence(out, x)
_plot_mesh(batch.pos, y, out, batch.batch, i) # _plot_mesh(batch.pos, y, out, batch.batch, i)
if converged: if converged:
break break
x = out x = out

View File

@@ -17,11 +17,11 @@ class FiniteDifferenceStep(MessagePassing):
spectral_norm(nn.Linear(hidden_dim // 2, hidden_dim)), spectral_norm(nn.Linear(hidden_dim // 2, hidden_dim)),
) )
self.update_net = nn.Sequential( # self.update_net = nn.Sequential(
spectral_norm(nn.Linear(2 * hidden_dim, hidden_dim)), # spectral_norm(nn.Linear(2 * hidden_dim, hidden_dim)),
nn.GELU(), # nn.GELU(),
spectral_norm(nn.Linear(hidden_dim, hidden_dim)), # spectral_norm(nn.Linear(hidden_dim, hidden_dim)),
) # )
self.out_net = nn.Sequential( self.out_net = nn.Sequential(
spectral_norm(nn.Linear(hidden_dim, hidden_dim // 2)), spectral_norm(nn.Linear(hidden_dim, hidden_dim // 2)),
@@ -47,8 +47,9 @@ class FiniteDifferenceStep(MessagePassing):
""" """
TODO: add docstring. TODO: add docstring.
""" """
update_input = torch.cat([x, aggr_out], dim=-1) # update_input = torch.cat([x, aggr_out], dim=-1)
return self.update_net(update_input) # return self.update_net(update_input)
return aggr_out
def aggregate(self, inputs, index, deg): def aggregate(self, inputs, index, deg):
""" """