import torch

from torch_geometric.data import Data

# Key under which the input node features are stored.
D_IN_KEYS = "x"
# Keys of the additional attributes to normalize alongside the target.
D_ATTR_KEYS = ["c", "edge_attr"]
# Key of the prediction target.
D_OUT_KEY = "y"
# All keys for which normalization statistics are tracked.
D_KEYS = D_ATTR_KEYS + [D_OUT_KEY]


class Normalizer:
    """Per-key normalizer for a list of torch_geometric Data objects."""

    def __init__(self, data):
        self.mean, self.std = self._compute_stats(data)

    def _compute_stats(self, data: list[Data]) -> tuple[dict, dict]:
        mean = {}
        std = {}
        for key in D_KEYS:
            chunks = []
            for d in data:
                if not hasattr(d, key):
                    raise AttributeError(f"Missing '{key}' in one of the Data objects.")
                chunks.append(d[key])
            # Concatenate the values of `key` across all graphs and compute
            # per-feature statistics over the stacked rows.
            tmp = torch.cat(chunks, dim=0)
            mean[key] = tmp.mean(dim=0, keepdim=True)
            # A small epsilon keeps the division safe for constant features.
            std[key] = tmp.std(dim=0, keepdim=True) + 1e-6
        return mean, std

    @staticmethod
    def _apply_input_boundary(data: Data):
        # Write the (already normalized) target values back into the input
        # features at boundary nodes, enforcing the boundary conditions.
        bc = data.y[data.boundary_mask]
        data[D_IN_KEYS][data.boundary_mask] = bc

    def normalize(self, data: list[Data]):
        # Normalize every tracked key in place, then re-impose the boundary
        # values on the input features.
        for d in data:
            for key in D_KEYS:
                d[key] = (d[key] - self.mean[key]) / self.std[key]
            self._apply_input_boundary(d)
        return data

    def denormalize(self, y: torch.Tensor):
        # Map a normalized prediction back to the original scale of the target.
        return y * self.std[D_OUT_KEY] + self.mean[D_OUT_KEY]
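

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how this Normalizer could be wired up, assuming each
# Data object carries node features `x`, node attributes `c`, `edge_attr`,
# a target `y` with the same feature size as `x`, and a boolean node-level
# `boundary_mask`. All names and shapes below are assumptions for the sketch.
if __name__ == "__main__":
    num_nodes, num_edges = 8, 16
    boundary_mask = torch.zeros(num_nodes, dtype=torch.bool)
    boundary_mask[:2] = True  # mark the first two nodes as boundary nodes

    dataset = [
        Data(
            x=torch.randn(num_nodes, 3),
            c=torch.randn(num_nodes, 2),
            edge_index=torch.randint(0, num_nodes, (2, num_edges)),
            edge_attr=torch.randn(num_edges, 4),
            y=torch.randn(num_nodes, 3),
            boundary_mask=boundary_mask,
        )
        for _ in range(4)
    ]

    normalizer = Normalizer(dataset)         # fit per-key mean/std on the whole list
    dataset = normalizer.normalize(dataset)  # normalize in place, re-apply boundary values
    y_pred = torch.randn(num_nodes, 3)       # stand-in for a model prediction
    y_denorm = normalizer.denormalize(y_pred)  # back to the original scale of `y`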