implement ML correction

This commit is contained in:
Filippo Olivo
2025-11-18 21:55:54 +01:00
parent 1c7b593762
commit d865556c9f
3 changed files with 64 additions and 135 deletions

View File

@@ -4,6 +4,7 @@ from torch_geometric.data import Batch
import importlib
from matplotlib import pyplot as plt
from matplotlib.tri import Triangulation
from .model.finite_difference import FiniteDifferenceStep
def import_class(class_path: str):
@@ -56,6 +57,7 @@ class GraphSolver(LightningModule):
):
super().__init__()
self.model = import_class(model_class_path)(**model_init_args)
self.fd_net = FiniteDifferenceStep()
self.loss = loss if loss is not None else torch.nn.MSELoss()
self.curriculum_learning = curriculum_learning
self.start_iters = start_iters
@@ -67,6 +69,8 @@ class GraphSolver(LightningModule):
self.automatic_optimization = False
self.threshold = 1e-5
self.aplha = 0.1
def _compute_deg(self, edge_index, edge_attr, num_nodes):
deg = torch.zeros(num_nodes, device=edge_index.device)
deg = deg.scatter_add(0, edge_index[1], edge_attr)
@@ -96,8 +100,15 @@ class GraphSolver(LightningModule):
def _compute_model_steps(
self, x, edge_index, edge_attr, deg, boundary_mask, boundary_values
):
out = self.model(x, edge_index, edge_attr, deg)
with torch.no_grad():
out = self.fd_net(x, edge_index, edge_attr, deg)
out[boundary_mask] = boundary_values.unsqueeze(-1)
# diff = out - x
correction = self.model(x, edge_index, edge_attr, deg)
out = out + self.aplha * correction
out[boundary_mask] = boundary_values.unsqueeze(-1)
# out = self.model(x, edge_index, edge_attr, deg)
# out[boundary_mask] = boundary_values.unsqueeze(-1)
return out
def _check_convergence(self, out, x):
@@ -132,11 +143,7 @@ class GraphSolver(LightningModule):
deg = self._compute_deg(edge_index, edge_attr, x.size(0))
losses = []
acc_loss, acc_it = 0, 0
max_acc_iters = (
self.current_iters // self.accumulation_iters + 1
if self.accumulation_iters is not None
else 1
)
for i in range(self.current_iters):
out = self._compute_model_steps(
x,