import torch
from torch_geometric.data import Data

# Attribute overwritten with boundary values by `_apply_input_boundary`
# (a single key, despite the plural name). It is NOT part of D_KEYS, so it is
# never standardized by `Normalizer.normalize`.
D_IN_KEYS = "x"
# Extra attributes that are standardized together with the output.
D_ATTR_KEYS = ["c", "edge_attr"]
# Attribute whose statistics are used by `Normalizer.denormalize`.
D_OUT_KEY = "y"
# Every attribute for which mean/std are computed and normalization is applied.
D_KEYS = D_ATTR_KEYS + [D_OUT_KEY]


class Normalizer:
    """Standardize selected attributes of a list of ``Data`` objects.

    Per-key mean and standard deviation are computed once, at construction
    time, over the concatenation (along dim 0) of that key across all the
    given ``Data`` objects. They are stored in ``self.mean`` and ``self.std``
    as dicts keyed by attribute name.
    """

    def __init__(self, data: list[Data]):
        """Compute and store the statistics of ``data`` for every key in ``D_KEYS``."""
        self.mean, self.std = self._compute_stats(data)

    def _compute_stats(self, data: list[Data]) -> tuple[dict, dict]:
        """Return ``(mean, std)`` dicts with one entry per key in ``D_KEYS``.

        For each key the tensors of all ``Data`` objects are concatenated
        along dim 0 and reduced along dim 0 with ``keepdim=True``. A small
        epsilon (1e-6) is added to the std so that dividing by it is safe for
        constant features.

        Note: ``torch.std`` applies Bessel's correction by default, so a key
        with fewer than two rows in total yields NaN statistics.

        Raises:
            AttributeError: if any ``Data`` object lacks one of ``D_KEYS``.
        """
        mean = {}
        std = {}
        for key in D_KEYS:
            tensors = []
            for d in data:
                if not hasattr(d, key):
                    raise AttributeError(f"Manca '{key}' in uno dei Data.")
                tensors.append(d[key])

            # Concatenate once instead of growing a tensor inside the loop,
            # which re-copied the accumulated data on every iteration.
            # Empty tensors contribute no rows, so they are skipped.
            non_empty = [t for t in tensors if t.numel() > 0]
            if non_empty:
                stacked = torch.cat(non_empty, dim=0)
            elif tensors:
                # Every tensor is empty: keep the last one, as the
                # incremental implementation did.
                stacked = tensors[-1]
            else:
                # No data at all: statistics of an empty tensor (NaN).
                stacked = torch.empty(0)

            mean[key] = stacked.mean(dim=0, keepdim=True)
            std[key] = stacked.std(dim=0, keepdim=True) + 1e-6
        return mean, std

    @staticmethod
    def _apply_input_boundary(data: Data) -> None:
        """Copy the output values at boundary entries into the input, in place.

        Entries of ``data[D_OUT_KEY]`` selected by ``data.boundary_mask`` are
        written to the same entries of ``data[D_IN_KEYS]``.
        """
        # NOTE(review): assumes input and output have compatible shapes at the
        # masked entries - confirm against the dataset.
        bc = data[D_OUT_KEY][data.boundary_mask]
        data[D_IN_KEYS][data.boundary_mask] = bc

    def normalize(self, data: list[Data]) -> list[Data]:
        """Standardize every key in ``D_KEYS`` of each ``Data`` object in place.

        After standardization, the (now normalized) output values at the
        boundary entries are copied into the input attribute.

        Returns:
            The same list that was passed in; its elements are mutated.
        """
        for d in data:
            for key in D_KEYS:
                d[key] = (d[key] - self.mean[key]) / self.std[key]
            self._apply_input_boundary(d)
        return data

    def denormalize(self, y: torch.Tensor) -> torch.Tensor:
        """Map a standardized output tensor back to the original scale.

        The stored statistics are moved to ``y``'s device first, so outputs
        produced on an accelerator can be denormalized directly.
        """
        std = self.std[D_OUT_KEY].to(y.device)
        mean = self.mean[D_OUT_KEY].to(y.device)
        return y * std + mean