"""Per-attribute z-score normalization for lists of torch_geometric ``Data`` objects."""

import torch
from torch_geometric.data import Data

# Key of the input features.
D_IN_KEYS = "x"
# Extra attribute keys normalized alongside input and target.
D_ATTR_KEYS = ["c", "edge_attr"]
# Key of the target; boundary temperatures are derived from it.
D_OUT_KEY = "y"
# Every attribute key that the Normalizer standardizes, in processing order.
D_KEYS = [D_IN_KEYS] + [D_OUT_KEY] + D_ATTR_KEYS
# Attribute rebuilt from the normalized target after normalization.
D_BOUNDS_KEYS = "boundary_temperatures"


class Normalizer:
    """Standardize the ``D_KEYS`` attributes of ``Data`` objects.

    Statistics are the mean and standard deviation along dim 0 of each
    attribute, concatenated across all graphs passed to the constructor.
    """

    def __init__(self, data, *, eps: float = 1e-6):
        """Compute normalization statistics from ``data``.

        Args:
            data: Iterable of ``Data`` objects; consumed exactly once.
            eps: Added to every standard deviation to avoid division by zero.

        Raises:
            AttributeError: if a ``Data`` lacks one of ``D_KEYS``.
            ValueError: if ``data`` is empty.
        """
        self.mean, self.std = self._compute_stats(data, eps)

    def _compute_stats(
        self, data: list[Data], eps: float = 1e-6
    ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        """Return ``(mean, std)`` dicts keyed by attribute name.

        Each value keeps dim 0 with size 1 (``keepdim=True``) so it
        broadcasts against the per-graph tensors.
        """
        # Materialize once: a generator would otherwise be exhausted.
        data = list(data)
        if not data:
            raise ValueError(
                "Impossibile calcolare le statistiche di normalizzazione "
                "da un dataset vuoto."
            )

        # Gather every tensor per key in a single pass, then concatenate once
        # (repeated torch.cat in a loop is quadratic in the dataset size).
        chunks: dict[str, list[torch.Tensor]] = {key: [] for key in D_KEYS}
        for d in data:
            for key in D_KEYS:
                if not hasattr(d, key):
                    raise AttributeError(f"Manca '{key}' in uno dei Data.")
                chunks[key].append(d[key])

        mean: dict[str, torch.Tensor] = {}
        std: dict[str, torch.Tensor] = {}
        for key, parts in chunks.items():
            # A single tensor is used as-is (this also keeps 0-d tensors working).
            stacked = parts[0] if len(parts) == 1 else torch.cat(parts, dim=0)
            mean[key] = stacked.mean(dim=0, keepdim=True)
            n_rows = stacked.shape[0] if stacked.dim() > 0 else 1
            if n_rows > 1:
                spread = stacked.std(dim=0, keepdim=True)
            else:
                # Unbiased std of a single sample is NaN; fall back to zero
                # spread so ``eps`` alone prevents division by zero.
                spread = torch.zeros_like(mean[key])
            std[key] = spread + eps
        return mean, std

    def normalize(self, data) -> None:
        """Standardize ``D_KEYS`` of every ``Data`` in ``data`` in place.

        Afterwards ``boundary_temperatures`` is recomputed from the
        normalized target.

        Raises:
            AttributeError: if a ``Data`` lacks one of ``D_KEYS``. Nothing
                is modified in that case.
        """
        # Materialize: ``data`` is traversed twice below.
        data = list(data)
        # Validate up front so a missing key cannot leave the dataset
        # half-normalized.
        for d in data:
            for key in D_KEYS:
                if not hasattr(d, key):
                    raise AttributeError(f"Manca '{key}' in uno dei Data.")
        for d in data:
            for key in D_KEYS:
                value = d[key]
                # .to() is a no-op when the stats already live on value's device.
                d[key] = (value - self.mean[key].to(value.device)) / self.std[
                    key
                ].to(value.device)
        self._recompute_boundary_temperatures(data)

    def _recompute_boundary_temperatures(self, data: list[Data]) -> None:
        """Set ``boundary_temperatures`` to the per-side median of the target.

        The result has shape ``(1, 4)``, dtype float32, with columns ordered
        bottom, right, top, left. The ``*_boundary_ids`` attributes are
        presumably node indices on each side of the domain; confirm against
        the dataset builder.
        """
        for d in data:
            target = d[D_OUT_KEY]
            sides = (
                d.bottom_boundary_ids,
                d.right_boundary_ids,
                d.top_boundary_ids,
                d.left_boundary_ids,
            )
            # torch.stack keeps the target's device and avoids the per-element
            # host sync that torch.tensor([...0-d tensors...]) incurs.
            medians = torch.stack([target[ids].median() for ids in sides])
            setattr(d, D_BOUNDS_KEYS, medians.to(torch.float32).unsqueeze(0))

    def denormalize(self, y: torch.Tensor) -> torch.Tensor:
        """Map normalized target values back to the original scale."""
        std = self.std[D_OUT_KEY].to(y.device)
        mean = self.mean[D_OUT_KEY].to(y.device)
        return y * std + mean