update model, model and datamodule

This commit is contained in:
FilippoOlivo
2025-10-14 10:02:39 +02:00
parent b9335cd2f8
commit 720931b831
4 changed files with 16 additions and 48 deletions

View File

@@ -4,8 +4,7 @@ from torch_geometric.data import Data
D_IN_KEYS = "x"
D_ATTR_KEYS = ["c", "edge_attr"]
D_OUT_KEY = "y"
D_KEYS = [D_IN_KEYS] + [D_OUT_KEY] + D_ATTR_KEYS
D_BOUNDS_KEYS = "boundary_temperatures"
D_KEYS = D_ATTR_KEYS + [D_OUT_KEY]
class Normalizer:
@@ -28,24 +27,17 @@ class Normalizer:
std[key] = tmp.std(dim=0, keepdim=True) + 1e-6
return mean, std
def normalize(self, data):
@staticmethod
def _apply_input_boundary(data: Data):
bc = data.y[data.boundary_mask]
data[D_IN_KEYS][data.boundary_mask] = bc
def normalize(self, data: list[Data]):
for d in data:
for key in D_KEYS:
if not hasattr(d, key):
raise AttributeError(f"Manca '{key}' in uno dei Data.")
d[key] = (d[key] - self.mean[key]) / self.std[key]
self._recompute_boundary_temperatures(data)
def _recompute_boundary_temperatures(self, data):
for d in data:
bottom_bc = d.y[d.bottom_boundary_ids].median()
top_bc = d.y[d.top_boundary_ids].median()
left_bc = d.y[d.left_boundary_ids].median()
right_bc = d.y[d.right_boundary_ids].median()
boundaries_temperatures = torch.tensor(
[bottom_bc, right_bc, top_bc, left_bc], dtype=torch.float32
)
d.boundary_temperatures = boundaries_temperatures.unsqueeze(0)
self._apply_input_boundary(d)
return data
def denormalize(self, y: torch.tensor):
return y * self.std[D_OUT_KEY] + self.mean[D_OUT_KEY]