update model, model and datamodule

This commit is contained in:
FilippoOlivo
2025-10-14 10:02:39 +02:00
parent b9335cd2f8
commit 720931b831
4 changed files with 16 additions and 48 deletions

View File

@@ -28,20 +28,18 @@ class GraphSolver(LightningModule):
self,
x: torch.Tensor,
c: torch.Tensor,
boundary: torch.Tensor,
boundary_mask: torch.Tensor,
edge_index: torch.Tensor,
edge_attr: torch.Tensor,
unrolling_steps: int = None,
boundary_mask: torch.Tensor = None,
):
return self.model(
x,
c,
boundary,
boundary_mask,
edge_index,
edge_attr,
unrolling_steps,
boundary_mask=boundary_mask,
)
def _compute_loss(self, x, y):
@@ -66,11 +64,10 @@ class GraphSolver(LightningModule):
y_pred = self(
x,
c,
batch.boundary_values,
batch.boundary_mask,
edge_index=edge_index,
edge_attr=edge_attr,
unrolling_steps=self.unrolling_steps,
boundary_mask=batch.boundary_mask,
)
loss = self.loss(y_pred, y)
boundary_loss = self.loss(
@@ -85,8 +82,6 @@ class GraphSolver(LightningModule):
y_pred = self(
x,
c,
batch.boundary_values,
batch.boundary_mask,
edge_index=edge_index,
edge_attr=edge_attr,
unrolling_steps=self.unrolling_steps,
@@ -104,8 +99,6 @@ class GraphSolver(LightningModule):
y_pred = self.model(
x,
c,
batch.boundary_values,
batch.boundary_mask,
edge_index=edge_index,
edge_attr=edge_attr,
unrolling_steps=self.unrolling_steps,