This commit is contained in:
FilippoOlivo
2025-09-26 09:10:02 +02:00
parent f3be9e99f8
commit c6c416e682
2 changed files with 16 additions and 2 deletions

View File

@@ -50,7 +50,10 @@ class GraphSolver(LightningModule):
for _ in range(self.unrolling_steps):
x_prev = x.detach()
x = self(x_prev, c, edge_index=edge_index, edge_attr=edge_attr)
loss += self.loss(x, y)
actual_loss = self.loss(x, y)
loss += actual_loss
print(f"Train step loss: {actual_loss.item()}")
self._log_loss(loss, batch, "train")
return loss
@@ -78,3 +81,7 @@ class GraphSolver(LightningModule):
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer
def scale_bc(self, data: Batch, y: torch.Tensor):
t = data.boundary_temperature[data.batch]
return y * t