update model, model and datamodule
This commit is contained in:
@@ -28,20 +28,18 @@ class GraphSolver(LightningModule):
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
c: torch.Tensor,
|
||||
boundary: torch.Tensor,
|
||||
boundary_mask: torch.Tensor,
|
||||
edge_index: torch.Tensor,
|
||||
edge_attr: torch.Tensor,
|
||||
unrolling_steps: int = None,
|
||||
boundary_mask: torch.Tensor = None,
|
||||
):
|
||||
return self.model(
|
||||
x,
|
||||
c,
|
||||
boundary,
|
||||
boundary_mask,
|
||||
edge_index,
|
||||
edge_attr,
|
||||
unrolling_steps,
|
||||
boundary_mask=boundary_mask,
|
||||
)
|
||||
|
||||
def _compute_loss(self, x, y):
|
||||
@@ -66,11 +64,10 @@ class GraphSolver(LightningModule):
|
||||
y_pred = self(
|
||||
x,
|
||||
c,
|
||||
batch.boundary_values,
|
||||
batch.boundary_mask,
|
||||
edge_index=edge_index,
|
||||
edge_attr=edge_attr,
|
||||
unrolling_steps=self.unrolling_steps,
|
||||
boundary_mask=batch.boundary_mask,
|
||||
)
|
||||
loss = self.loss(y_pred, y)
|
||||
boundary_loss = self.loss(
|
||||
@@ -85,8 +82,6 @@ class GraphSolver(LightningModule):
|
||||
y_pred = self(
|
||||
x,
|
||||
c,
|
||||
batch.boundary_values,
|
||||
batch.boundary_mask,
|
||||
edge_index=edge_index,
|
||||
edge_attr=edge_attr,
|
||||
unrolling_steps=self.unrolling_steps,
|
||||
@@ -104,8 +99,6 @@ class GraphSolver(LightningModule):
|
||||
y_pred = self.model(
|
||||
x,
|
||||
c,
|
||||
batch.boundary_values,
|
||||
batch.boundary_mask,
|
||||
edge_index=edge_index,
|
||||
edge_attr=edge_attr,
|
||||
unrolling_steps=self.unrolling_steps,
|
||||
|
||||
Reference in New Issue
Block a user