improve unrolling
This commit is contained in:
51
ThermalSolver/normalizer.py
Normal file
51
ThermalSolver/normalizer.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import torch
|
||||
from torch_geometric.data import Data
|
||||
|
||||
D_IN_KEYS = "x"
|
||||
D_ATTR_KEYS = ["c", "edge_attr"]
|
||||
D_OUT_KEY = "y"
|
||||
D_KEYS = [D_IN_KEYS] + [D_OUT_KEY] + D_ATTR_KEYS
|
||||
D_BOUNDS_KEYS = "boundary_temperatures"
|
||||
|
||||
|
||||
class Normalizer:
|
||||
def __init__(self, data):
|
||||
self.mean, self.std = self._compute_stats(data)
|
||||
|
||||
def _compute_stats(self, data: list[Data]) -> tuple[dict, dict]:
|
||||
mean = {}
|
||||
std = {}
|
||||
for key in D_KEYS:
|
||||
tmp = torch.empty(0)
|
||||
for d in data:
|
||||
if not hasattr(d, key):
|
||||
raise AttributeError(f"Manca '{key}' in uno dei Data.")
|
||||
if tmp.numel() == 0:
|
||||
tmp = d[key]
|
||||
else:
|
||||
tmp = torch.cat([tmp, d[key]], dim=0)
|
||||
mean[key] = tmp.mean(dim=0, keepdim=True)
|
||||
std[key] = tmp.std(dim=0, keepdim=True) + 1e-6
|
||||
return mean, std
|
||||
|
||||
def normalize(self, data):
|
||||
for d in data:
|
||||
for key in D_KEYS:
|
||||
if not hasattr(d, key):
|
||||
raise AttributeError(f"Manca '{key}' in uno dei Data.")
|
||||
d[key] = (d[key] - self.mean[key]) / self.std[key]
|
||||
self._recompute_boundary_temperatures(data)
|
||||
|
||||
def _recompute_boundary_temperatures(self, data):
|
||||
for d in data:
|
||||
bottom_bc = d.y[d.bottom_boundary_ids].median()
|
||||
top_bc = d.y[d.top_boundary_ids].median()
|
||||
left_bc = d.y[d.left_boundary_ids].median()
|
||||
right_bc = d.y[d.right_boundary_ids].median()
|
||||
boundaries_temperatures = torch.tensor(
|
||||
[bottom_bc, right_bc, top_bc, left_bc], dtype=torch.float32
|
||||
)
|
||||
d.boundary_temperatures = boundaries_temperatures.unsqueeze(0)
|
||||
|
||||
def denormalize(self, y: torch.tensor):
|
||||
return y * self.std[D_OUT_KEY] + self.mean[D_OUT_KEY]
|
||||
Reference in New Issue
Block a user