implement new model

This commit is contained in:
Filippo Olivo
2025-10-03 10:42:05 +02:00
parent b1a9cecb42
commit 7a6fbdb89c
3 changed files with 80 additions and 54 deletions

View File

@@ -17,6 +17,7 @@ class GraphDataModule(LightningDataModule):
val_size: float = 0.1, val_size: float = 0.1,
test_size: float = 0.1, test_size: float = 0.1,
batch_size: int = 32, batch_size: int = 32,
remove_boundary_edges: bool = True,
): ):
super().__init__() super().__init__()
self.hf_repo = hf_repo self.hf_repo = hf_repo
@@ -27,6 +28,7 @@ class GraphDataModule(LightningDataModule):
self.val_size = val_size self.val_size = val_size
self.test_size = test_size self.test_size = test_size
self.batch_size = batch_size self.batch_size = batch_size
self.remove_boundary_edges = remove_boundary_edges
def prepare_data(self): def prepare_data(self):
hf_dataset = load_dataset(self.hf_repo, name="snapshots")[ hf_dataset = load_dataset(self.hf_repo, name="snapshots")[
@@ -44,6 +46,28 @@ class GraphDataModule(LightningDataModule):
) )
] ]
def _compute_boundary_mask(
self, bottom_ids, right_ids, top_ids, left_ids, temperature
):
left_ids = left_ids[~torch.isin(left_ids, bottom_ids)]
right_ids = right_ids[~torch.isin(right_ids, bottom_ids)]
left_ids = left_ids[~torch.isin(left_ids, top_ids)]
right_ids = right_ids[~torch.isin(right_ids, top_ids)]
bottom_bc = temperature[bottom_ids].median()
bottom_bc_mask = torch.ones(len(bottom_ids)) * bottom_bc
left_bc = temperature[left_ids].median()
left_bc_mask = torch.ones(len(left_ids)) * left_bc
right_bc = temperature[right_ids].median()
right_bc_mask = torch.ones(len(right_ids)) * right_bc
boundary_values = torch.cat(
[bottom_bc_mask, right_bc_mask, left_bc_mask], dim=0
)
boundary_mask = torch.cat([bottom_ids, right_ids, left_ids], dim=0)
return boundary_mask, boundary_values
def _build_dataset( def _build_dataset(
self, self,
snapshot: dict, snapshot: dict,
@@ -66,27 +90,34 @@ class GraphDataModule(LightningDataModule):
) )
edge_index = to_undirected(edge_index, num_nodes=pos.size(0)) edge_index = to_undirected(edge_index, num_nodes=pos.size(0))
boundary_mask, boundary_values = self._compute_boundary_mask(
bottom_ids, right_ids, top_ids, left_ids, temperature
)
if self.remove_boundary_edges:
boundary_idx = torch.unique(boundary_mask)
edge_index_mask = ~torch.isin(edge_index[1], boundary_idx)
edge_index = edge_index[:, edge_index_mask]
edge_attr = pos[edge_index[0]] - pos[edge_index[1]] edge_attr = pos[edge_index[0]] - pos[edge_index[1]]
edge_attr = torch.cat( edge_attr = torch.cat(
[edge_attr, torch.norm(edge_attr, dim=1).unsqueeze(-1)], dim=1 [edge_attr, torch.norm(edge_attr, dim=1).unsqueeze(-1)], dim=1
) )
left_ids = left_ids[~torch.isin(left_ids, bottom_ids)] x = torch.zeros_like(temperature, dtype=torch.float32).unsqueeze(-1)
right_ids = right_ids[~torch.isin(right_ids, bottom_ids)] if self.remove_boundary_edges:
left_ids = left_ids[~torch.isin(left_ids, top_ids)] x[boundary_mask] = boundary_values.unsqueeze(-1)
right_ids = right_ids[~torch.isin(right_ids, top_ids)] return MeshData(
x=x,
bottom_bc = temperature[bottom_ids].median() c=conductivity.unsqueeze(-1),
bottom_bc_mask = torch.ones(len(bottom_ids)) * bottom_bc edge_index=edge_index,
left_bc = temperature[left_ids].median() pos=pos,
left_bc_mask = torch.ones(len(left_ids)) * left_bc edge_attr=edge_attr,
right_bc = temperature[right_ids].median() y=temperature.unsqueeze(-1),
right_bc_mask = torch.ones(len(right_ids)) * right_bc boundary_mask=boundary_mask,
boundary_values=torch.tensor(0),
boundary_values = torch.cat( )
[bottom_bc_mask, right_bc_mask, left_bc_mask], dim=0
)
boundary_mask = torch.cat([bottom_ids, right_ids, left_ids], dim=0)
return MeshData( return MeshData(
x=torch.rand_like(temperature).unsqueeze(-1), x=torch.rand_like(temperature).unsqueeze(-1),
@@ -110,7 +141,6 @@ class GraphDataModule(LightningDataModule):
if stage == "test" or stage is None: if stage == "test" or stage is None:
self.test_data = self.data[val_end:] self.test_data = self.data[val_end:]
# nel tuo LightningDataModule
def train_dataloader(self): def train_dataloader(self):
return DataLoader( return DataLoader(
self.train_data, batch_size=self.batch_size, shuffle=True self.train_data, batch_size=self.batch_size, shuffle=True

View File

@@ -4,10 +4,6 @@ from torch_geometric.nn import MessagePassing
from matplotlib.tri import Triangulation from matplotlib.tri import Triangulation
def _import_boundary_conditions(x, boundary, boundary_mask):
x[boundary_mask] = boundary
def plot_results_fn(x, pos, i, batch): def plot_results_fn(x, pos, i, batch):
x = x[batch == 0] x = x[batch == 0]
pos = pos[batch == 0] pos = pos[batch == 0]
@@ -63,7 +59,6 @@ class DecX(nn.Module):
class ConditionalGNOBlock(MessagePassing): class ConditionalGNOBlock(MessagePassing):
def __init__(self, hidden_ch, edge_ch=0, aggr="mean"): def __init__(self, hidden_ch, edge_ch=0, aggr="mean"):
super().__init__(aggr=aggr, node_dim=0) super().__init__(aggr=aggr, node_dim=0)
# self.film_msg = FiLM(c_ch=hidden_ch, h_ch=hidden_ch)
self.edge_attr_net = nn.Sequential( self.edge_attr_net = nn.Sequential(
nn.Linear(edge_ch, hidden_ch // 2), nn.Linear(edge_ch, hidden_ch // 2),
@@ -73,9 +68,9 @@ class ConditionalGNOBlock(MessagePassing):
) )
self.msg_proj = nn.Sequential( self.msg_proj = nn.Sequential(
nn.Linear(hidden_ch, hidden_ch), nn.Linear(hidden_ch, hidden_ch, bias=False),
nn.SiLU(), nn.SiLU(),
nn.Linear(hidden_ch, hidden_ch), nn.Linear(hidden_ch, hidden_ch, bias=False),
) )
self.diff_net = nn.Sequential( self.diff_net = nn.Sequential(
@@ -98,7 +93,15 @@ class ConditionalGNOBlock(MessagePassing):
) )
self.balancing = nn.Parameter(torch.tensor(0.0)) self.balancing = nn.Parameter(torch.tensor(0.0))
self.alpha = nn.Parameter(torch.tensor(1.0))
self.alpha_net = nn.Sequential(
nn.Linear(2 * hidden_ch, hidden_ch),
nn.SiLU(),
nn.Linear(hidden_ch, hidden_ch // 2),
nn.SiLU(),
nn.Linear(hidden_ch // 2, 1),
nn.Sigmoid(),
)
def forward(self, x, c, edge_index, edge_attr=None): def forward(self, x, c, edge_index, edge_attr=None):
return self.propagate(edge_index, x=x, c=c, edge_attr=edge_attr) return self.propagate(edge_index, x=x, c=c, edge_attr=edge_attr)
@@ -108,14 +111,14 @@ class ConditionalGNOBlock(MessagePassing):
alpha = torch.sigmoid(self.balancing) alpha = torch.sigmoid(self.balancing)
gate = torch.sigmoid(self.edge_attr_net(edge_attr)) gate = torch.sigmoid(self.edge_attr_net(edge_attr))
m = ( m = (
alpha * self.diff_net(x_j - x_i) alpha * self.diff_net(x_j - x_i) + (1 - alpha) * self.x_net(x_j)
+ (1 - alpha) * self.x_net(x_j) * gate ) * gate
)
m = m * self.c_ij_net(c_ij) m = m * self.c_ij_net(c_ij)
return m return m
def update(self, aggr_out, x): def update(self, aggr_out, x):
return x + self.alpha * self.msg_proj(aggr_out) alpha = self.alpha_net(torch.cat([x, aggr_out], dim=-1))
return x + alpha * self.msg_proj(aggr_out)
class GatingGNO(nn.Module): class GatingGNO(nn.Module):
@@ -153,14 +156,10 @@ class GatingGNO(nn.Module):
): ):
x = self.encoder_x(x) x = self.encoder_x(x)
c = self.encoder_c(c) c = self.encoder_c(c)
boundary = self.encoder_x(boundary)
if plot_results: if plot_results:
_import_boundary_conditions(x, boundary, boundary_mask)
x_ = self.dec(x) x_ = self.dec(x)
plot_results_fn(x_, pos, 0, batch=batch) plot_results_fn(x_, pos, 0, batch=batch)
for _ in range(1, unrolling_steps + 1): for _ in range(1, unrolling_steps + 1):
_import_boundary_conditions(x, boundary, boundary_mask)
for blk in self.blocks: for blk in self.blocks:
x = blk(x, c, edge_index, edge_attr=edge_attr) x = blk(x, c, edge_index, edge_attr=edge_attr)
if plot_results: if plot_results:

View File

@@ -1,31 +1,26 @@
import torch import torch
from lightning import LightningModule from lightning import LightningModule
from torch_geometric.data import Batch from torch_geometric.data import Batch
from matplotlib.tri import Triangulation import importlib
# def plot_results(x, pos, step, i, batch): def import_class(class_path: str):
# x = x[batch == 0] module_path, class_name = class_path.rsplit(".", 1) # split last dot
# pos = pos[batch == 0] module = importlib.import_module(module_path) # import the module
# tria = Triangulation(pos[:, 0].cpu(), pos[:, 1].cpu()) cls = getattr(module, class_name) # get the class
# import matplotlib.pyplot as plt return cls
# plt.tricontourf(tria, x[:, 0].cpu(), levels=14)
# plt.colorbar()
# plt.savefig(f"{step:03d}_out_{i:03d}.png")
# plt.axis("equal")
# plt.close()
class GraphSolver(LightningModule): class GraphSolver(LightningModule):
def __init__( def __init__(
self, self,
model: torch.nn.Module, model_class_path: str,
model_init_args: dict,
loss: torch.nn.Module = None, loss: torch.nn.Module = None,
unrolling_steps: int = 48, unrolling_steps: int = 48,
): ):
super().__init__() super().__init__()
self.model = model self.model = import_class(model_class_path)(**model_init_args)
self.loss = loss if loss is not None else torch.nn.MSELoss() self.loss = loss if loss is not None else torch.nn.MSELoss()
self.unrolling_steps = unrolling_steps self.unrolling_steps = unrolling_steps
@@ -57,7 +52,7 @@ class GraphSolver(LightningModule):
def _log_loss(self, loss, batch, stage: str): def _log_loss(self, loss, batch, stage: str):
self.log( self.log(
f"{stage}_loss", f"{stage}/loss",
loss, loss,
on_step=False, on_step=False,
on_epoch=True, on_epoch=True,
@@ -68,8 +63,6 @@ class GraphSolver(LightningModule):
def training_step(self, batch: Batch, _): def training_step(self, batch: Batch, _):
x, y, c, edge_index, edge_attr = self._preprocess_batch(batch) x, y, c, edge_index, edge_attr = self._preprocess_batch(batch)
# x = self._impose_bc(x, batch)
# for _ in range(self.unrolling_steps):
y_pred = self( y_pred = self(
x, x,
c, c,
@@ -79,9 +72,12 @@ class GraphSolver(LightningModule):
edge_attr=edge_attr, edge_attr=edge_attr,
unrolling_steps=self.unrolling_steps, unrolling_steps=self.unrolling_steps,
) )
# x = self._impose_bc(x, batch)
loss = self.loss(y_pred, y) loss = self.loss(y_pred, y)
boundary_loss = self.loss(
y_pred[batch.boundary_mask], y[batch.boundary_mask]
)
self._log_loss(loss, batch, "train") self._log_loss(loss, batch, "train")
self._log_loss(boundary_loss, batch, "train_boundary")
return loss return loss
def validation_step(self, batch: Batch, _): def validation_step(self, batch: Batch, _):
@@ -96,12 +92,15 @@ class GraphSolver(LightningModule):
unrolling_steps=self.unrolling_steps, unrolling_steps=self.unrolling_steps,
) )
loss = self.loss(y_pred, y) loss = self.loss(y_pred, y)
boundary_loss = self.loss(
y_pred[batch.boundary_mask], y[batch.boundary_mask]
)
self._log_loss(loss, batch, "val") self._log_loss(loss, batch, "val")
self._log_loss(boundary_loss, batch, "val_boundary")
return loss return loss
def test_step(self, batch: Batch, _): def test_step(self, batch: Batch, _):
x, y, c, edge_index, edge_attr = self._preprocess_batch(batch) x, y, c, edge_index, edge_attr = self._preprocess_batch(batch)
# for _ in range(self.unrolling_steps):
y_pred = self.model( y_pred = self.model(
x, x,
c, c,
@@ -114,8 +113,6 @@ class GraphSolver(LightningModule):
batch=batch.batch, batch=batch.batch,
pos=batch.pos, pos=batch.pos,
) )
# x = self._impose_bc(x, batch)
# plot_results(x, batch.pos, self.global_step, _, batch.batch)
loss = self._compute_loss(y_pred, y) loss = self._compute_loss(y_pred, y)
self._log_loss(loss, batch, "test") self._log_loss(loss, batch, "test")
return loss return loss