implement new model

This commit is contained in:
Filippo Olivo
2025-10-03 10:42:05 +02:00
parent b1a9cecb42
commit 7a6fbdb89c
3 changed files with 80 additions and 54 deletions

View File

@@ -17,6 +17,7 @@ class GraphDataModule(LightningDataModule):
val_size: float = 0.1,
test_size: float = 0.1,
batch_size: int = 32,
remove_boundary_edges: bool = True,
):
super().__init__()
self.hf_repo = hf_repo
@@ -27,6 +28,7 @@ class GraphDataModule(LightningDataModule):
self.val_size = val_size
self.test_size = test_size
self.batch_size = batch_size
self.remove_boundary_edges = remove_boundary_edges
def prepare_data(self):
hf_dataset = load_dataset(self.hf_repo, name="snapshots")[
@@ -44,6 +46,28 @@ class GraphDataModule(LightningDataModule):
)
]
def _compute_boundary_mask(
self, bottom_ids, right_ids, top_ids, left_ids, temperature
):
left_ids = left_ids[~torch.isin(left_ids, bottom_ids)]
right_ids = right_ids[~torch.isin(right_ids, bottom_ids)]
left_ids = left_ids[~torch.isin(left_ids, top_ids)]
right_ids = right_ids[~torch.isin(right_ids, top_ids)]
bottom_bc = temperature[bottom_ids].median()
bottom_bc_mask = torch.ones(len(bottom_ids)) * bottom_bc
left_bc = temperature[left_ids].median()
left_bc_mask = torch.ones(len(left_ids)) * left_bc
right_bc = temperature[right_ids].median()
right_bc_mask = torch.ones(len(right_ids)) * right_bc
boundary_values = torch.cat(
[bottom_bc_mask, right_bc_mask, left_bc_mask], dim=0
)
boundary_mask = torch.cat([bottom_ids, right_ids, left_ids], dim=0)
return boundary_mask, boundary_values
def _build_dataset(
self,
snapshot: dict,
@@ -66,27 +90,34 @@ class GraphDataModule(LightningDataModule):
)
edge_index = to_undirected(edge_index, num_nodes=pos.size(0))
boundary_mask, boundary_values = self._compute_boundary_mask(
bottom_ids, right_ids, top_ids, left_ids, temperature
)
if self.remove_boundary_edges:
boundary_idx = torch.unique(boundary_mask)
edge_index_mask = ~torch.isin(edge_index[1], boundary_idx)
edge_index = edge_index[:, edge_index_mask]
edge_attr = pos[edge_index[0]] - pos[edge_index[1]]
edge_attr = torch.cat(
[edge_attr, torch.norm(edge_attr, dim=1).unsqueeze(-1)], dim=1
)
left_ids = left_ids[~torch.isin(left_ids, bottom_ids)]
right_ids = right_ids[~torch.isin(right_ids, bottom_ids)]
left_ids = left_ids[~torch.isin(left_ids, top_ids)]
right_ids = right_ids[~torch.isin(right_ids, top_ids)]
bottom_bc = temperature[bottom_ids].median()
bottom_bc_mask = torch.ones(len(bottom_ids)) * bottom_bc
left_bc = temperature[left_ids].median()
left_bc_mask = torch.ones(len(left_ids)) * left_bc
right_bc = temperature[right_ids].median()
right_bc_mask = torch.ones(len(right_ids)) * right_bc
boundary_values = torch.cat(
[bottom_bc_mask, right_bc_mask, left_bc_mask], dim=0
)
boundary_mask = torch.cat([bottom_ids, right_ids, left_ids], dim=0)
x = torch.zeros_like(temperature, dtype=torch.float32).unsqueeze(-1)
if self.remove_boundary_edges:
x[boundary_mask] = boundary_values.unsqueeze(-1)
return MeshData(
x=x,
c=conductivity.unsqueeze(-1),
edge_index=edge_index,
pos=pos,
edge_attr=edge_attr,
y=temperature.unsqueeze(-1),
boundary_mask=boundary_mask,
boundary_values=torch.tensor(0),
)
return MeshData(
x=torch.rand_like(temperature).unsqueeze(-1),
@@ -110,7 +141,6 @@ class GraphDataModule(LightningDataModule):
if stage == "test" or stage is None:
self.test_data = self.data[val_end:]
# nel tuo LightningDataModule
def train_dataloader(self):
return DataLoader(
self.train_data, batch_size=self.batch_size, shuffle=True

View File

@@ -4,10 +4,6 @@ from torch_geometric.nn import MessagePassing
from matplotlib.tri import Triangulation
def _import_boundary_conditions(x, boundary, boundary_mask):
x[boundary_mask] = boundary
def plot_results_fn(x, pos, i, batch):
x = x[batch == 0]
pos = pos[batch == 0]
@@ -63,7 +59,6 @@ class DecX(nn.Module):
class ConditionalGNOBlock(MessagePassing):
def __init__(self, hidden_ch, edge_ch=0, aggr="mean"):
super().__init__(aggr=aggr, node_dim=0)
# self.film_msg = FiLM(c_ch=hidden_ch, h_ch=hidden_ch)
self.edge_attr_net = nn.Sequential(
nn.Linear(edge_ch, hidden_ch // 2),
@@ -73,9 +68,9 @@ class ConditionalGNOBlock(MessagePassing):
)
self.msg_proj = nn.Sequential(
nn.Linear(hidden_ch, hidden_ch),
nn.Linear(hidden_ch, hidden_ch, bias=False),
nn.SiLU(),
nn.Linear(hidden_ch, hidden_ch),
nn.Linear(hidden_ch, hidden_ch, bias=False),
)
self.diff_net = nn.Sequential(
@@ -98,7 +93,15 @@ class ConditionalGNOBlock(MessagePassing):
)
self.balancing = nn.Parameter(torch.tensor(0.0))
self.alpha = nn.Parameter(torch.tensor(1.0))
self.alpha_net = nn.Sequential(
nn.Linear(2 * hidden_ch, hidden_ch),
nn.SiLU(),
nn.Linear(hidden_ch, hidden_ch // 2),
nn.SiLU(),
nn.Linear(hidden_ch // 2, 1),
nn.Sigmoid(),
)
def forward(self, x, c, edge_index, edge_attr=None):
return self.propagate(edge_index, x=x, c=c, edge_attr=edge_attr)
@@ -108,14 +111,14 @@ class ConditionalGNOBlock(MessagePassing):
alpha = torch.sigmoid(self.balancing)
gate = torch.sigmoid(self.edge_attr_net(edge_attr))
m = (
alpha * self.diff_net(x_j - x_i)
+ (1 - alpha) * self.x_net(x_j) * gate
)
alpha * self.diff_net(x_j - x_i) + (1 - alpha) * self.x_net(x_j)
) * gate
m = m * self.c_ij_net(c_ij)
return m
def update(self, aggr_out, x):
return x + self.alpha * self.msg_proj(aggr_out)
alpha = self.alpha_net(torch.cat([x, aggr_out], dim=-1))
return x + alpha * self.msg_proj(aggr_out)
class GatingGNO(nn.Module):
@@ -153,14 +156,10 @@ class GatingGNO(nn.Module):
):
x = self.encoder_x(x)
c = self.encoder_c(c)
boundary = self.encoder_x(boundary)
if plot_results:
_import_boundary_conditions(x, boundary, boundary_mask)
x_ = self.dec(x)
plot_results_fn(x_, pos, 0, batch=batch)
for _ in range(1, unrolling_steps + 1):
_import_boundary_conditions(x, boundary, boundary_mask)
for blk in self.blocks:
x = blk(x, c, edge_index, edge_attr=edge_attr)
if plot_results:

View File

@@ -1,31 +1,26 @@
import torch
from lightning import LightningModule
from torch_geometric.data import Batch
from matplotlib.tri import Triangulation
import importlib
# def plot_results(x, pos, step, i, batch):
# x = x[batch == 0]
# pos = pos[batch == 0]
# tria = Triangulation(pos[:, 0].cpu(), pos[:, 1].cpu())
# import matplotlib.pyplot as plt
# plt.tricontourf(tria, x[:, 0].cpu(), levels=14)
# plt.colorbar()
# plt.savefig(f"{step:03d}_out_{i:03d}.png")
# plt.axis("equal")
# plt.close()
def import_class(class_path: str):
module_path, class_name = class_path.rsplit(".", 1) # split last dot
module = importlib.import_module(module_path) # import the module
cls = getattr(module, class_name) # get the class
return cls
class GraphSolver(LightningModule):
def __init__(
self,
model: torch.nn.Module,
model_class_path: str,
model_init_args: dict,
loss: torch.nn.Module = None,
unrolling_steps: int = 48,
):
super().__init__()
self.model = model
self.model = import_class(model_class_path)(**model_init_args)
self.loss = loss if loss is not None else torch.nn.MSELoss()
self.unrolling_steps = unrolling_steps
@@ -57,7 +52,7 @@ class GraphSolver(LightningModule):
def _log_loss(self, loss, batch, stage: str):
self.log(
f"{stage}_loss",
f"{stage}/loss",
loss,
on_step=False,
on_epoch=True,
@@ -68,8 +63,6 @@ class GraphSolver(LightningModule):
def training_step(self, batch: Batch, _):
x, y, c, edge_index, edge_attr = self._preprocess_batch(batch)
# x = self._impose_bc(x, batch)
# for _ in range(self.unrolling_steps):
y_pred = self(
x,
c,
@@ -79,9 +72,12 @@ class GraphSolver(LightningModule):
edge_attr=edge_attr,
unrolling_steps=self.unrolling_steps,
)
# x = self._impose_bc(x, batch)
loss = self.loss(y_pred, y)
boundary_loss = self.loss(
y_pred[batch.boundary_mask], y[batch.boundary_mask]
)
self._log_loss(loss, batch, "train")
self._log_loss(boundary_loss, batch, "train_boundary")
return loss
def validation_step(self, batch: Batch, _):
@@ -96,12 +92,15 @@ class GraphSolver(LightningModule):
unrolling_steps=self.unrolling_steps,
)
loss = self.loss(y_pred, y)
boundary_loss = self.loss(
y_pred[batch.boundary_mask], y[batch.boundary_mask]
)
self._log_loss(loss, batch, "val")
self._log_loss(boundary_loss, batch, "val_boundary")
return loss
def test_step(self, batch: Batch, _):
x, y, c, edge_index, edge_attr = self._preprocess_batch(batch)
# for _ in range(self.unrolling_steps):
y_pred = self.model(
x,
c,
@@ -114,8 +113,6 @@ class GraphSolver(LightningModule):
batch=batch.batch,
pos=batch.pos,
)
# x = self._impose_bc(x, batch)
# plot_results(x, batch.pos, self.global_step, _, batch.batch)
loss = self._compute_loss(y_pred, y)
self._log_loss(loss, batch, "test")
return loss