implement ML correction
This commit is contained in:
@@ -4,6 +4,7 @@ from torch_geometric.data import Batch
|
||||
import importlib
|
||||
from matplotlib import pyplot as plt
|
||||
from matplotlib.tri import Triangulation
|
||||
from .model.finite_difference import FiniteDifferenceStep
|
||||
|
||||
|
||||
def import_class(class_path: str):
|
||||
@@ -56,6 +57,7 @@ class GraphSolver(LightningModule):
|
||||
):
|
||||
super().__init__()
|
||||
self.model = import_class(model_class_path)(**model_init_args)
|
||||
self.fd_net = FiniteDifferenceStep()
|
||||
self.loss = loss if loss is not None else torch.nn.MSELoss()
|
||||
self.curriculum_learning = curriculum_learning
|
||||
self.start_iters = start_iters
|
||||
@@ -67,6 +69,8 @@ class GraphSolver(LightningModule):
|
||||
self.automatic_optimization = False
|
||||
self.threshold = 1e-5
|
||||
|
||||
self.aplha = 0.1
|
||||
|
||||
def _compute_deg(self, edge_index, edge_attr, num_nodes):
|
||||
deg = torch.zeros(num_nodes, device=edge_index.device)
|
||||
deg = deg.scatter_add(0, edge_index[1], edge_attr)
|
||||
@@ -96,8 +100,15 @@ class GraphSolver(LightningModule):
|
||||
def _compute_model_steps(
|
||||
self, x, edge_index, edge_attr, deg, boundary_mask, boundary_values
|
||||
):
|
||||
out = self.model(x, edge_index, edge_attr, deg)
|
||||
with torch.no_grad():
|
||||
out = self.fd_net(x, edge_index, edge_attr, deg)
|
||||
out[boundary_mask] = boundary_values.unsqueeze(-1)
|
||||
# diff = out - x
|
||||
correction = self.model(x, edge_index, edge_attr, deg)
|
||||
out = out + self.aplha * correction
|
||||
out[boundary_mask] = boundary_values.unsqueeze(-1)
|
||||
# out = self.model(x, edge_index, edge_attr, deg)
|
||||
# out[boundary_mask] = boundary_values.unsqueeze(-1)
|
||||
return out
|
||||
|
||||
def _check_convergence(self, out, x):
|
||||
@@ -132,11 +143,7 @@ class GraphSolver(LightningModule):
|
||||
deg = self._compute_deg(edge_index, edge_attr, x.size(0))
|
||||
losses = []
|
||||
acc_loss, acc_it = 0, 0
|
||||
max_acc_iters = (
|
||||
self.current_iters // self.accumulation_iters + 1
|
||||
if self.accumulation_iters is not None
|
||||
else 1
|
||||
)
|
||||
|
||||
for i in range(self.current_iters):
|
||||
out = self._compute_model_steps(
|
||||
x,
|
||||
|
||||
Reference in New Issue
Block a user