implement new model
@@ -17,6 +17,7 @@ class GraphDataModule(LightningDataModule):
         val_size: float = 0.1,
         test_size: float = 0.1,
         batch_size: int = 32,
+        remove_boundary_edges: bool = True,
     ):
         super().__init__()
         self.hf_repo = hf_repo
@@ -27,6 +28,7 @@ class GraphDataModule(LightningDataModule):
         self.val_size = val_size
         self.test_size = test_size
         self.batch_size = batch_size
+        self.remove_boundary_edges = remove_boundary_edges
 
     def prepare_data(self):
         hf_dataset = load_dataset(self.hf_repo, name="snapshots")[
@@ -44,6 +46,28 @@ class GraphDataModule(LightningDataModule):
             )
         ]
 
+    def _compute_boundary_mask(
+        self, bottom_ids, right_ids, top_ids, left_ids, temperature
+    ):
+        left_ids = left_ids[~torch.isin(left_ids, bottom_ids)]
+        right_ids = right_ids[~torch.isin(right_ids, bottom_ids)]
+        left_ids = left_ids[~torch.isin(left_ids, top_ids)]
+        right_ids = right_ids[~torch.isin(right_ids, top_ids)]
+
+        bottom_bc = temperature[bottom_ids].median()
+        bottom_bc_mask = torch.ones(len(bottom_ids)) * bottom_bc
+        left_bc = temperature[left_ids].median()
+        left_bc_mask = torch.ones(len(left_ids)) * left_bc
+        right_bc = temperature[right_ids].median()
+        right_bc_mask = torch.ones(len(right_ids)) * right_bc
+
+        boundary_values = torch.cat(
+            [bottom_bc_mask, right_bc_mask, left_bc_mask], dim=0
+        )
+        boundary_mask = torch.cat([bottom_ids, right_ids, left_ids], dim=0)
+
+        return boundary_mask, boundary_values
+
     def _build_dataset(
         self,
         snapshot: dict,
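
Note (illustrative, not part of the commit): read on a toy 3x3 grid with made-up node ids and temperatures, the helper above strips the corner nodes from the left/right columns, assigns each remaining side the median temperature of its nodes as a Dirichlet value, and leaves the top side out of the returned mask entirely.

import torch

# Toy 3x3 grid, nodes numbered row-major from the bottom-left corner:
#   6 7 8
#   3 4 5
#   0 1 2
bottom_ids = torch.tensor([0, 1, 2])
top_ids = torch.tensor([6, 7, 8])
left_ids = torch.tensor([0, 3, 6])
right_ids = torch.tensor([2, 5, 8])
temperature = torch.arange(9, dtype=torch.float32)  # made-up nodal values

# Same steps as _compute_boundary_mask: drop the corners shared with the
# bottom/top rows from the left/right columns ...
left_ids = left_ids[~torch.isin(left_ids, bottom_ids)]     # -> [3, 6]
right_ids = right_ids[~torch.isin(right_ids, bottom_ids)]  # -> [5, 8]
left_ids = left_ids[~torch.isin(left_ids, top_ids)]        # -> [3]
right_ids = right_ids[~torch.isin(right_ids, top_ids)]     # -> [5]

# ... then broadcast each side's median temperature over its nodes.
bottom_bc_mask = torch.ones(len(bottom_ids)) * temperature[bottom_ids].median()
left_bc_mask = torch.ones(len(left_ids)) * temperature[left_ids].median()
right_bc_mask = torch.ones(len(right_ids)) * temperature[right_ids].median()

boundary_mask = torch.cat([bottom_ids, right_ids, left_ids], dim=0)
boundary_values = torch.cat([bottom_bc_mask, right_bc_mask, left_bc_mask], dim=0)
print(boundary_mask)    # tensor([0, 1, 2, 5, 3])
print(boundary_values)  # tensor([1., 1., 1., 5., 3.])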
@@ -66,27 +90,34 @@ class GraphDataModule(LightningDataModule):
         )
 
         edge_index = to_undirected(edge_index, num_nodes=pos.size(0))
 
+        boundary_mask, boundary_values = self._compute_boundary_mask(
+            bottom_ids, right_ids, top_ids, left_ids, temperature
+        )
+
+        if self.remove_boundary_edges:
+            boundary_idx = torch.unique(boundary_mask)
+            edge_index_mask = ~torch.isin(edge_index[1], boundary_idx)
+            edge_index = edge_index[:, edge_index_mask]
+
         edge_attr = pos[edge_index[0]] - pos[edge_index[1]]
         edge_attr = torch.cat(
             [edge_attr, torch.norm(edge_attr, dim=1).unsqueeze(-1)], dim=1
         )
 
-        left_ids = left_ids[~torch.isin(left_ids, bottom_ids)]
-        right_ids = right_ids[~torch.isin(right_ids, bottom_ids)]
-        left_ids = left_ids[~torch.isin(left_ids, top_ids)]
-        right_ids = right_ids[~torch.isin(right_ids, top_ids)]
-        bottom_bc = temperature[bottom_ids].median()
-        bottom_bc_mask = torch.ones(len(bottom_ids)) * bottom_bc
-        left_bc = temperature[left_ids].median()
-        left_bc_mask = torch.ones(len(left_ids)) * left_bc
-        right_bc = temperature[right_ids].median()
-        right_bc_mask = torch.ones(len(right_ids)) * right_bc
-        boundary_values = torch.cat(
-            [bottom_bc_mask, right_bc_mask, left_bc_mask], dim=0
+        x = torch.zeros_like(temperature, dtype=torch.float32).unsqueeze(-1)
+        if self.remove_boundary_edges:
+            x[boundary_mask] = boundary_values.unsqueeze(-1)
+        return MeshData(
+            x=x,
+            c=conductivity.unsqueeze(-1),
+            edge_index=edge_index,
+            pos=pos,
+            edge_attr=edge_attr,
+            y=temperature.unsqueeze(-1),
+            boundary_mask=boundary_mask,
+            boundary_values=torch.tensor(0),
         )
-        boundary_mask = torch.cat([bottom_ids, right_ids, left_ids], dim=0)
-
 
         return MeshData(
             x=torch.rand_like(temperature).unsqueeze(-1),
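
Note (illustrative, not part of the commit): a minimal sketch, with made-up tensors, of what the remove_boundary_edges branch does to edge_index. Only the target row (edge_index[1]) is masked, so boundary nodes stop receiving messages but can still act as sources for their interior neighbours; a plausible reading is that this keeps the boundary values written into x from being overwritten during message passing.

import torch

edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 0, 3, 2]])   # made-up 4-node graph
boundary_idx = torch.tensor([0])             # node 0 sits on the boundary

# Keep only edges whose target is NOT a boundary node.
edge_index_mask = ~torch.isin(edge_index[1], boundary_idx)
edge_index = edge_index[:, edge_index_mask]
print(edge_index)
# tensor([[0, 2, 3],
#         [1, 3, 2]])  -> the edge pointing into node 0 is gone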
@@ -110,7 +141,6 @@ class GraphDataModule(LightningDataModule):
         if stage == "test" or stage is None:
             self.test_data = self.data[val_end:]
 
-    # nel tuo LightningDataModule
     def train_dataloader(self):
         return DataLoader(
             self.train_data, batch_size=self.batch_size, shuffle=True
@@ -4,10 +4,6 @@ from torch_geometric.nn import MessagePassing
 from matplotlib.tri import Triangulation
 
 
-def _import_boundary_conditions(x, boundary, boundary_mask):
-    x[boundary_mask] = boundary
-
-
 def plot_results_fn(x, pos, i, batch):
     x = x[batch == 0]
     pos = pos[batch == 0]
@@ -63,7 +59,6 @@ class DecX(nn.Module):
 class ConditionalGNOBlock(MessagePassing):
     def __init__(self, hidden_ch, edge_ch=0, aggr="mean"):
         super().__init__(aggr=aggr, node_dim=0)
-        # self.film_msg = FiLM(c_ch=hidden_ch, h_ch=hidden_ch)
 
         self.edge_attr_net = nn.Sequential(
             nn.Linear(edge_ch, hidden_ch // 2),
@@ -73,9 +68,9 @@ class ConditionalGNOBlock(MessagePassing):
         )
 
         self.msg_proj = nn.Sequential(
-            nn.Linear(hidden_ch, hidden_ch),
+            nn.Linear(hidden_ch, hidden_ch, bias=False),
             nn.SiLU(),
-            nn.Linear(hidden_ch, hidden_ch),
+            nn.Linear(hidden_ch, hidden_ch, bias=False),
         )
 
         self.diff_net = nn.Sequential(
@@ -98,7 +93,15 @@ class ConditionalGNOBlock(MessagePassing):
         )
 
         self.balancing = nn.Parameter(torch.tensor(0.0))
-        self.alpha = nn.Parameter(torch.tensor(1.0))
+
+        self.alpha_net = nn.Sequential(
+            nn.Linear(2 * hidden_ch, hidden_ch),
+            nn.SiLU(),
+            nn.Linear(hidden_ch, hidden_ch // 2),
+            nn.SiLU(),
+            nn.Linear(hidden_ch // 2, 1),
+            nn.Sigmoid(),
+        )
 
     def forward(self, x, c, edge_index, edge_attr=None):
         return self.propagate(edge_index, x=x, c=c, edge_attr=edge_attr)
@@ -108,14 +111,14 @@ class ConditionalGNOBlock(MessagePassing):
         alpha = torch.sigmoid(self.balancing)
         gate = torch.sigmoid(self.edge_attr_net(edge_attr))
         m = (
-            alpha * self.diff_net(x_j - x_i)
-            + (1 - alpha) * self.x_net(x_j) * gate
-        )
+            alpha * self.diff_net(x_j - x_i) + (1 - alpha) * self.x_net(x_j)
+        ) * gate
         m = m * self.c_ij_net(c_ij)
         return m
 
     def update(self, aggr_out, x):
-        return x + self.alpha * self.msg_proj(aggr_out)
+        alpha = self.alpha_net(torch.cat([x, aggr_out], dim=-1))
+        return x + alpha * self.msg_proj(aggr_out)
 
 
 class GatingGNO(nn.Module):
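
Note (illustrative, not part of the commit): stripped of the MessagePassing plumbing, the revised message/update pair boils down to the ops below. The layers are shape-only stand-ins for the block's sub-networks (diff_net, x_net, c_ij_net, edge_attr_net, msg_proj, alpha_net), not the real architecture; the point is that the geometry gate now multiplies the whole blended message, and the fixed scalar self.alpha in update() is replaced by a learned per-node gate in [0, 1].

import torch
import torch.nn as nn

hidden_ch = 8
# Stand-ins with the right input/output shapes, weights untrained.
diff_net = nn.Linear(hidden_ch, hidden_ch)
x_net = nn.Linear(hidden_ch, hidden_ch)
c_ij_net = nn.Linear(hidden_ch, hidden_ch)
edge_attr_net = nn.Linear(3, hidden_ch)     # edge_attr = (dx, dy, |d|)
msg_proj = nn.Linear(hidden_ch, hidden_ch, bias=False)
alpha_net = nn.Sequential(nn.Linear(2 * hidden_ch, 1), nn.Sigmoid())
balancing = torch.tensor(0.0)

x_i = torch.randn(5, hidden_ch)       # receiver features, one row per edge
x_j = torch.randn(5, hidden_ch)       # sender features
c_ij = torch.randn(5, hidden_ch)      # per-edge conditioning term
edge_attr = torch.randn(5, 3)

# message(): blend a difference term and a sender term, gate the whole
# message by the edge geometry, then modulate by the conditioning term.
alpha = torch.sigmoid(balancing)
gate = torch.sigmoid(edge_attr_net(edge_attr))
m = (alpha * diff_net(x_j - x_i) + (1 - alpha) * x_net(x_j)) * gate
m = m * c_ij_net(c_ij)

# update(): a per-node gate in [0, 1], computed from the node state and the
# aggregated message, now scales the residual update.
x = torch.randn(4, hidden_ch)         # node features
aggr_out = torch.randn(4, hidden_ch)  # aggregated messages per node
alpha_node = alpha_net(torch.cat([x, aggr_out], dim=-1))  # shape (4, 1)
x_new = x + alpha_node * msg_proj(aggr_out)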
@@ -153,14 +156,10 @@ class GatingGNO(nn.Module):
     ):
         x = self.encoder_x(x)
         c = self.encoder_c(c)
-        boundary = self.encoder_x(boundary)
         if plot_results:
-            _import_boundary_conditions(x, boundary, boundary_mask)
             x_ = self.dec(x)
             plot_results_fn(x_, pos, 0, batch=batch)
-
         for _ in range(1, unrolling_steps + 1):
-            _import_boundary_conditions(x, boundary, boundary_mask)
             for blk in self.blocks:
                 x = blk(x, c, edge_index, edge_attr=edge_attr)
             if plot_results:
@@ -1,31 +1,26 @@
 import torch
 from lightning import LightningModule
 from torch_geometric.data import Batch
-from matplotlib.tri import Triangulation
+import importlib
 
 
-# def plot_results(x, pos, step, i, batch):
-#     x = x[batch == 0]
-#     pos = pos[batch == 0]
-#     tria = Triangulation(pos[:, 0].cpu(), pos[:, 1].cpu())
-#     import matplotlib.pyplot as plt
+def import_class(class_path: str):
+    module_path, class_name = class_path.rsplit(".", 1)  # split last dot
+    module = importlib.import_module(module_path)  # import the module
+    cls = getattr(module, class_name)  # get the class
+    return cls
 
-#     plt.tricontourf(tria, x[:, 0].cpu(), levels=14)
-#     plt.colorbar()
-#     plt.savefig(f"{step:03d}_out_{i:03d}.png")
-#     plt.axis("equal")
-#     plt.close()
-
 
 
 class GraphSolver(LightningModule):
     def __init__(
         self,
-        model: torch.nn.Module,
+        model_class_path: str,
+        model_init_args: dict,
         loss: torch.nn.Module = None,
         unrolling_steps: int = 48,
     ):
         super().__init__()
-        self.model = model
+        self.model = import_class(model_class_path)(**model_init_args)
         self.loss = loss if loss is not None else torch.nn.MSELoss()
         self.unrolling_steps = unrolling_steps
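
Note (illustrative, not part of the commit): with the constructor now taking a dotted class path plus init kwargs, the model can be chosen from a config file or CLI rather than passed in as a live nn.Module. The path and arguments below are placeholders, not taken from this repository.

# Hypothetical usage; the dotted path must point at the module that actually
# defines GatingGNO, and the kwargs must match its real constructor.
solver = GraphSolver(
    model_class_path="my_project.models.GatingGNO",    # placeholder path
    model_init_args={"hidden_ch": 64, "edge_ch": 3},   # assumed constructor args
    unrolling_steps=48,
)

# import_class itself is just rsplit + importlib + getattr, e.g.:
# import_class("torch.nn.Linear") is torch.nn.Linear   -> True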
@@ -57,7 +52,7 @@ class GraphSolver(LightningModule):
 
     def _log_loss(self, loss, batch, stage: str):
         self.log(
-            f"{stage}_loss",
+            f"{stage}/loss",
             loss,
             on_step=False,
             on_epoch=True,
@@ -68,8 +63,6 @@ class GraphSolver(LightningModule):
 
     def training_step(self, batch: Batch, _):
         x, y, c, edge_index, edge_attr = self._preprocess_batch(batch)
-        # x = self._impose_bc(x, batch)
-        # for _ in range(self.unrolling_steps):
         y_pred = self(
             x,
             c,
@@ -79,9 +72,12 @@ class GraphSolver(LightningModule):
             edge_attr=edge_attr,
             unrolling_steps=self.unrolling_steps,
         )
-        # x = self._impose_bc(x, batch)
         loss = self.loss(y_pred, y)
+        boundary_loss = self.loss(
+            y_pred[batch.boundary_mask], y[batch.boundary_mask]
+        )
         self._log_loss(loss, batch, "train")
+        self._log_loss(boundary_loss, batch, "train_boundary")
         return loss
 
     def validation_step(self, batch: Batch, _):
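
Note (illustrative, not part of the commit): the added boundary term is simply the same criterion restricted to the node indices stored in batch.boundary_mask, logged next to the full-field loss. Toy numbers below:

import torch

loss_fn = torch.nn.MSELoss()
y_pred = torch.tensor([[0.9], [2.2], [3.1], [3.9]])
y = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
boundary_mask = torch.tensor([0, 3])   # indices of boundary nodes

loss = loss_fn(y_pred, y)                                        # full-field loss
boundary_loss = loss_fn(y_pred[boundary_mask], y[boundary_mask])
# boundary_loss = ((0.9 - 1.0)**2 + (3.9 - 4.0)**2) / 2 ≈ 0.01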
@@ -96,12 +92,15 @@ class GraphSolver(LightningModule):
             unrolling_steps=self.unrolling_steps,
         )
         loss = self.loss(y_pred, y)
+        boundary_loss = self.loss(
+            y_pred[batch.boundary_mask], y[batch.boundary_mask]
+        )
         self._log_loss(loss, batch, "val")
+        self._log_loss(boundary_loss, batch, "val_boundary")
         return loss
 
     def test_step(self, batch: Batch, _):
         x, y, c, edge_index, edge_attr = self._preprocess_batch(batch)
-        # for _ in range(self.unrolling_steps):
         y_pred = self.model(
             x,
             c,
@@ -114,8 +113,6 @@ class GraphSolver(LightningModule):
             batch=batch.batch,
             pos=batch.pos,
         )
-        # x = self._impose_bc(x, batch)
-        # plot_results(x, batch.pos, self.global_step, _, batch.batch)
         loss = self._compute_loss(y_pred, y)
         self._log_loss(loss, batch, "test")
         return loss