implement new model
This commit is contained in:
@@ -17,6 +17,7 @@ class GraphDataModule(LightningDataModule):
|
||||
val_size: float = 0.1,
|
||||
test_size: float = 0.1,
|
||||
batch_size: int = 32,
|
||||
remove_boundary_edges: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.hf_repo = hf_repo
|
||||
@@ -27,6 +28,7 @@ class GraphDataModule(LightningDataModule):
|
||||
self.val_size = val_size
|
||||
self.test_size = test_size
|
||||
self.batch_size = batch_size
|
||||
self.remove_boundary_edges = remove_boundary_edges
|
||||
|
||||
def prepare_data(self):
|
||||
hf_dataset = load_dataset(self.hf_repo, name="snapshots")[
|
||||
@@ -44,6 +46,28 @@ class GraphDataModule(LightningDataModule):
|
||||
)
|
||||
]
|
||||
|
||||
def _compute_boundary_mask(
|
||||
self, bottom_ids, right_ids, top_ids, left_ids, temperature
|
||||
):
|
||||
left_ids = left_ids[~torch.isin(left_ids, bottom_ids)]
|
||||
right_ids = right_ids[~torch.isin(right_ids, bottom_ids)]
|
||||
left_ids = left_ids[~torch.isin(left_ids, top_ids)]
|
||||
right_ids = right_ids[~torch.isin(right_ids, top_ids)]
|
||||
|
||||
bottom_bc = temperature[bottom_ids].median()
|
||||
bottom_bc_mask = torch.ones(len(bottom_ids)) * bottom_bc
|
||||
left_bc = temperature[left_ids].median()
|
||||
left_bc_mask = torch.ones(len(left_ids)) * left_bc
|
||||
right_bc = temperature[right_ids].median()
|
||||
right_bc_mask = torch.ones(len(right_ids)) * right_bc
|
||||
|
||||
boundary_values = torch.cat(
|
||||
[bottom_bc_mask, right_bc_mask, left_bc_mask], dim=0
|
||||
)
|
||||
boundary_mask = torch.cat([bottom_ids, right_ids, left_ids], dim=0)
|
||||
|
||||
return boundary_mask, boundary_values
|
||||
|
||||
def _build_dataset(
|
||||
self,
|
||||
snapshot: dict,
|
||||
@@ -66,27 +90,34 @@ class GraphDataModule(LightningDataModule):
|
||||
)
|
||||
|
||||
edge_index = to_undirected(edge_index, num_nodes=pos.size(0))
|
||||
|
||||
boundary_mask, boundary_values = self._compute_boundary_mask(
|
||||
bottom_ids, right_ids, top_ids, left_ids, temperature
|
||||
)
|
||||
|
||||
if self.remove_boundary_edges:
|
||||
boundary_idx = torch.unique(boundary_mask)
|
||||
edge_index_mask = ~torch.isin(edge_index[1], boundary_idx)
|
||||
edge_index = edge_index[:, edge_index_mask]
|
||||
|
||||
edge_attr = pos[edge_index[0]] - pos[edge_index[1]]
|
||||
edge_attr = torch.cat(
|
||||
[edge_attr, torch.norm(edge_attr, dim=1).unsqueeze(-1)], dim=1
|
||||
)
|
||||
|
||||
left_ids = left_ids[~torch.isin(left_ids, bottom_ids)]
|
||||
right_ids = right_ids[~torch.isin(right_ids, bottom_ids)]
|
||||
left_ids = left_ids[~torch.isin(left_ids, top_ids)]
|
||||
right_ids = right_ids[~torch.isin(right_ids, top_ids)]
|
||||
|
||||
bottom_bc = temperature[bottom_ids].median()
|
||||
bottom_bc_mask = torch.ones(len(bottom_ids)) * bottom_bc
|
||||
left_bc = temperature[left_ids].median()
|
||||
left_bc_mask = torch.ones(len(left_ids)) * left_bc
|
||||
right_bc = temperature[right_ids].median()
|
||||
right_bc_mask = torch.ones(len(right_ids)) * right_bc
|
||||
|
||||
boundary_values = torch.cat(
|
||||
[bottom_bc_mask, right_bc_mask, left_bc_mask], dim=0
|
||||
)
|
||||
boundary_mask = torch.cat([bottom_ids, right_ids, left_ids], dim=0)
|
||||
x = torch.zeros_like(temperature, dtype=torch.float32).unsqueeze(-1)
|
||||
if self.remove_boundary_edges:
|
||||
x[boundary_mask] = boundary_values.unsqueeze(-1)
|
||||
return MeshData(
|
||||
x=x,
|
||||
c=conductivity.unsqueeze(-1),
|
||||
edge_index=edge_index,
|
||||
pos=pos,
|
||||
edge_attr=edge_attr,
|
||||
y=temperature.unsqueeze(-1),
|
||||
boundary_mask=boundary_mask,
|
||||
boundary_values=torch.tensor(0),
|
||||
)
|
||||
|
||||
return MeshData(
|
||||
x=torch.rand_like(temperature).unsqueeze(-1),
|
||||
@@ -110,7 +141,6 @@ class GraphDataModule(LightningDataModule):
|
||||
if stage == "test" or stage is None:
|
||||
self.test_data = self.data[val_end:]
|
||||
|
||||
# nel tuo LightningDataModule
|
||||
def train_dataloader(self):
|
||||
return DataLoader(
|
||||
self.train_data, batch_size=self.batch_size, shuffle=True
|
||||
|
||||
@@ -4,10 +4,6 @@ from torch_geometric.nn import MessagePassing
|
||||
from matplotlib.tri import Triangulation
|
||||
|
||||
|
||||
def _import_boundary_conditions(x, boundary, boundary_mask):
|
||||
x[boundary_mask] = boundary
|
||||
|
||||
|
||||
def plot_results_fn(x, pos, i, batch):
|
||||
x = x[batch == 0]
|
||||
pos = pos[batch == 0]
|
||||
@@ -63,7 +59,6 @@ class DecX(nn.Module):
|
||||
class ConditionalGNOBlock(MessagePassing):
|
||||
def __init__(self, hidden_ch, edge_ch=0, aggr="mean"):
|
||||
super().__init__(aggr=aggr, node_dim=0)
|
||||
# self.film_msg = FiLM(c_ch=hidden_ch, h_ch=hidden_ch)
|
||||
|
||||
self.edge_attr_net = nn.Sequential(
|
||||
nn.Linear(edge_ch, hidden_ch // 2),
|
||||
@@ -73,9 +68,9 @@ class ConditionalGNOBlock(MessagePassing):
|
||||
)
|
||||
|
||||
self.msg_proj = nn.Sequential(
|
||||
nn.Linear(hidden_ch, hidden_ch),
|
||||
nn.Linear(hidden_ch, hidden_ch, bias=False),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_ch, hidden_ch),
|
||||
nn.Linear(hidden_ch, hidden_ch, bias=False),
|
||||
)
|
||||
|
||||
self.diff_net = nn.Sequential(
|
||||
@@ -98,7 +93,15 @@ class ConditionalGNOBlock(MessagePassing):
|
||||
)
|
||||
|
||||
self.balancing = nn.Parameter(torch.tensor(0.0))
|
||||
self.alpha = nn.Parameter(torch.tensor(1.0))
|
||||
|
||||
self.alpha_net = nn.Sequential(
|
||||
nn.Linear(2 * hidden_ch, hidden_ch),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_ch, hidden_ch // 2),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_ch // 2, 1),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
def forward(self, x, c, edge_index, edge_attr=None):
|
||||
return self.propagate(edge_index, x=x, c=c, edge_attr=edge_attr)
|
||||
@@ -108,14 +111,14 @@ class ConditionalGNOBlock(MessagePassing):
|
||||
alpha = torch.sigmoid(self.balancing)
|
||||
gate = torch.sigmoid(self.edge_attr_net(edge_attr))
|
||||
m = (
|
||||
alpha * self.diff_net(x_j - x_i)
|
||||
+ (1 - alpha) * self.x_net(x_j) * gate
|
||||
)
|
||||
alpha * self.diff_net(x_j - x_i) + (1 - alpha) * self.x_net(x_j)
|
||||
) * gate
|
||||
m = m * self.c_ij_net(c_ij)
|
||||
return m
|
||||
|
||||
def update(self, aggr_out, x):
|
||||
return x + self.alpha * self.msg_proj(aggr_out)
|
||||
alpha = self.alpha_net(torch.cat([x, aggr_out], dim=-1))
|
||||
return x + alpha * self.msg_proj(aggr_out)
|
||||
|
||||
|
||||
class GatingGNO(nn.Module):
|
||||
@@ -153,14 +156,10 @@ class GatingGNO(nn.Module):
|
||||
):
|
||||
x = self.encoder_x(x)
|
||||
c = self.encoder_c(c)
|
||||
boundary = self.encoder_x(boundary)
|
||||
if plot_results:
|
||||
_import_boundary_conditions(x, boundary, boundary_mask)
|
||||
x_ = self.dec(x)
|
||||
plot_results_fn(x_, pos, 0, batch=batch)
|
||||
|
||||
for _ in range(1, unrolling_steps + 1):
|
||||
_import_boundary_conditions(x, boundary, boundary_mask)
|
||||
for blk in self.blocks:
|
||||
x = blk(x, c, edge_index, edge_attr=edge_attr)
|
||||
if plot_results:
|
||||
|
||||
@@ -1,31 +1,26 @@
|
||||
import torch
|
||||
from lightning import LightningModule
|
||||
from torch_geometric.data import Batch
|
||||
from matplotlib.tri import Triangulation
|
||||
import importlib
|
||||
|
||||
|
||||
# def plot_results(x, pos, step, i, batch):
|
||||
# x = x[batch == 0]
|
||||
# pos = pos[batch == 0]
|
||||
# tria = Triangulation(pos[:, 0].cpu(), pos[:, 1].cpu())
|
||||
# import matplotlib.pyplot as plt
|
||||
|
||||
# plt.tricontourf(tria, x[:, 0].cpu(), levels=14)
|
||||
# plt.colorbar()
|
||||
# plt.savefig(f"{step:03d}_out_{i:03d}.png")
|
||||
# plt.axis("equal")
|
||||
# plt.close()
|
||||
def import_class(class_path: str):
|
||||
module_path, class_name = class_path.rsplit(".", 1) # split last dot
|
||||
module = importlib.import_module(module_path) # import the module
|
||||
cls = getattr(module, class_name) # get the class
|
||||
return cls
|
||||
|
||||
|
||||
class GraphSolver(LightningModule):
|
||||
def __init__(
|
||||
self,
|
||||
model: torch.nn.Module,
|
||||
model_class_path: str,
|
||||
model_init_args: dict,
|
||||
loss: torch.nn.Module = None,
|
||||
unrolling_steps: int = 48,
|
||||
):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.model = import_class(model_class_path)(**model_init_args)
|
||||
self.loss = loss if loss is not None else torch.nn.MSELoss()
|
||||
self.unrolling_steps = unrolling_steps
|
||||
|
||||
@@ -57,7 +52,7 @@ class GraphSolver(LightningModule):
|
||||
|
||||
def _log_loss(self, loss, batch, stage: str):
|
||||
self.log(
|
||||
f"{stage}_loss",
|
||||
f"{stage}/loss",
|
||||
loss,
|
||||
on_step=False,
|
||||
on_epoch=True,
|
||||
@@ -68,8 +63,6 @@ class GraphSolver(LightningModule):
|
||||
|
||||
def training_step(self, batch: Batch, _):
|
||||
x, y, c, edge_index, edge_attr = self._preprocess_batch(batch)
|
||||
# x = self._impose_bc(x, batch)
|
||||
# for _ in range(self.unrolling_steps):
|
||||
y_pred = self(
|
||||
x,
|
||||
c,
|
||||
@@ -79,9 +72,12 @@ class GraphSolver(LightningModule):
|
||||
edge_attr=edge_attr,
|
||||
unrolling_steps=self.unrolling_steps,
|
||||
)
|
||||
# x = self._impose_bc(x, batch)
|
||||
loss = self.loss(y_pred, y)
|
||||
boundary_loss = self.loss(
|
||||
y_pred[batch.boundary_mask], y[batch.boundary_mask]
|
||||
)
|
||||
self._log_loss(loss, batch, "train")
|
||||
self._log_loss(boundary_loss, batch, "train_boundary")
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch: Batch, _):
|
||||
@@ -96,12 +92,15 @@ class GraphSolver(LightningModule):
|
||||
unrolling_steps=self.unrolling_steps,
|
||||
)
|
||||
loss = self.loss(y_pred, y)
|
||||
boundary_loss = self.loss(
|
||||
y_pred[batch.boundary_mask], y[batch.boundary_mask]
|
||||
)
|
||||
self._log_loss(loss, batch, "val")
|
||||
self._log_loss(boundary_loss, batch, "val_boundary")
|
||||
return loss
|
||||
|
||||
def test_step(self, batch: Batch, _):
|
||||
x, y, c, edge_index, edge_attr = self._preprocess_batch(batch)
|
||||
# for _ in range(self.unrolling_steps):
|
||||
y_pred = self.model(
|
||||
x,
|
||||
c,
|
||||
@@ -114,8 +113,6 @@ class GraphSolver(LightningModule):
|
||||
batch=batch.batch,
|
||||
pos=batch.pos,
|
||||
)
|
||||
# x = self._impose_bc(x, batch)
|
||||
# plot_results(x, batch.pos, self.global_step, _, batch.batch)
|
||||
loss = self._compute_loss(y_pred, y)
|
||||
self._log_loss(loss, batch, "test")
|
||||
return loss
|
||||
|
||||
Reference in New Issue
Block a user