minor
This commit is contained in:
@@ -50,7 +50,10 @@ class GraphSolver(LightningModule):
|
||||
for _ in range(self.unrolling_steps):
|
||||
x_prev = x.detach()
|
||||
x = self(x_prev, c, edge_index=edge_index, edge_attr=edge_attr)
|
||||
loss += self.loss(x, y)
|
||||
actual_loss = self.loss(x, y)
|
||||
loss += actual_loss
|
||||
print(f"Train step loss: {actual_loss.item()}")
|
||||
|
||||
self._log_loss(loss, batch, "train")
|
||||
return loss
|
||||
|
||||
@@ -78,3 +81,7 @@ class GraphSolver(LightningModule):
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
|
||||
return optimizer
|
||||
|
||||
def scale_bc(self, data: Batch, y: torch.Tensor):
|
||||
t = data.boundary_temperature[data.batch]
|
||||
return y * t
|
||||
|
||||
Reference in New Issue
Block a user