diff --git a/ThermalSolver/model/__init__.py b/ThermalSolver/model/__init__.py new file mode 100644 index 0000000..cf8c596 --- /dev/null +++ b/ThermalSolver/model/__init__.py @@ -0,0 +1,5 @@ +__all__ = ["GraphFiniteDifference", "GatingGNO"] + +from .finite_difference import GraphFiniteDifference +from .local_gno import GatingGNO +from .point_net import PointNet diff --git a/ThermalSolver/model/finite_difference.py b/ThermalSolver/model/finite_difference.py new file mode 100644 index 0000000..28bee77 --- /dev/null +++ b/ThermalSolver/model/finite_difference.py @@ -0,0 +1,102 @@ +import torch +import torch.nn as nn +from torch_geometric.nn import MessagePassing +from tqdm import tqdm + + +class FiniteDifferenceStep(MessagePassing): + """ + TODO: add docstring. + """ + + def __init__( + self, + aggr: str = "add", + normalize: bool = True, + root_weight: float = 1.0, + ): + super().__init__(aggr=aggr) + + self.normalize = normalize + assert ( + aggr == "add" + ), "Per somme pesate, l'aggregazione deve essere 'add'." + self.root_weight = float(root_weight) + + def forward(self, x, edge_index, edge_weight, deg): + """ + TODO: add docstring. + """ + out = self.propagate(edge_index, x=x, edge_weight=edge_weight, deg=deg) + return out + + def message(self, x_j, edge_weight): + """ + TODO: add docstring. + """ + return edge_weight.view(-1, 1) * x_j + + def aggregate(self, inputs, index, deg): + """ + TODO: add docstring. + """ + out = super().aggregate(inputs, index) + deg = deg + 1e-7 + return out / deg.view(-1, 1) + + def update(self, aggr_out, x): + """ + TODO: add docstring. + """ + return self.root_weight * aggr_out + (1 - self.root_weight) * x + + +class GraphFiniteDifference(nn.Module): + """ + TODO: add docstring. + """ + + def __init__(self, max_iters: int = 1000, threshold: float = 1e-4): + """ + TODO: add docstring. + """ + super().__init__() + self.max_iters = max_iters + self.threshold = threshold + self.fd_step = FiniteDifferenceStep( + aggr="add", normalize=True, root_weight=1.0 + ) + + @staticmethod + def _compute_deg(edge_index, edge_weight, num_nodes): + """ + TODO: add docstring. + """ + deg = torch.zeros(num_nodes, device=edge_index.device) + deg = deg.scatter_add(0, edge_index[1], edge_weight) + return deg + 1e-7 + + @staticmethod + def _compute_c_ij(c, edge_index): + """ + TODO: add docstring. + """ + return (0.5 * (c[edge_index[0]] + c[edge_index[1]])).squeeze() + + def forward( + self, x, edge_index, edge_weight, c, boundary_mask, boundary_values + ): + """ + TODO: add docstring. + """ + c_ij = self._compute_c_ij(c, edge_index) + edge_weight = edge_weight * c_ij + deg = self._compute_deg(edge_index, edge_weight, x.size(0)) + conv_thres = self.threshold * torch.norm(x) + for _i in tqdm(range(self.max_iters)): + out = self.fd_step(x, edge_index, edge_weight, deg) + out[boundary_mask] = boundary_values.unsqueeze(-1) + if torch.norm(out - x) < conv_thres: + break + x = out + return out