This commit is contained in:
FilippoOlivo
2025-09-26 09:10:02 +02:00
parent f3be9e99f8
commit c6c416e682
2 changed files with 16 additions and 2 deletions

View File

@@ -40,6 +40,9 @@ class GraphDataModule(LightningDataModule):
pos = torch.tensor(self.geometry["points"][0], dtype=torch.float32)[ pos = torch.tensor(self.geometry["points"][0], dtype=torch.float32)[
:, :2 :, :2
] ]
bottom_boundary_ids = torch.tensor(
self.geometry["bottom_boundary_ids"][0], dtype=torch.int64
)
self.data = [ self.data = [
self._build_dataset( self._build_dataset(
torch.tensor(snapshot["conductivity"], dtype=torch.float32), torch.tensor(snapshot["conductivity"], dtype=torch.float32),
@@ -47,6 +50,7 @@ class GraphDataModule(LightningDataModule):
torch.tensor(snapshot["temperature"], dtype=torch.float32), torch.tensor(snapshot["temperature"], dtype=torch.float32),
edge_index.T, edge_index.T,
pos, pos,
bottom_boundary_ids,
) )
for snapshot in tqdm(hf_dataset, desc="Building graphs") for snapshot in tqdm(hf_dataset, desc="Building graphs")
] ]
@@ -58,13 +62,15 @@ class GraphDataModule(LightningDataModule):
temperature: torch.Tensor, temperature: torch.Tensor,
edge_index: torch.Tensor, edge_index: torch.Tensor,
pos: torch.Tensor, pos: torch.Tensor,
bottom_boundary_ids: torch.Tensor,
) -> Data: ) -> Data:
edge_index = to_undirected(edge_index, num_nodes=pos.size(0)) edge_index = to_undirected(edge_index, num_nodes=pos.size(0))
edge_attr = pos[edge_index[0]] - pos[edge_index[1]] edge_attr = pos[edge_index[0]] - pos[edge_index[1]]
edge_attr = torch.cat( edge_attr = torch.cat(
[edge_attr, torch.norm(edge_attr, dim=1).unsqueeze(-1)], dim=1 [edge_attr, torch.norm(edge_attr, dim=1).unsqueeze(-1)], dim=1
) )
boundary_temperature = boundary_vales[bottom_boundary_ids].max()
boundary_vales[bottom_boundary_ids] = 1.0
return Data( return Data(
x=boundary_vales.unsqueeze(-1), x=boundary_vales.unsqueeze(-1),
c=conductivity.unsqueeze(-1), c=conductivity.unsqueeze(-1),
@@ -72,6 +78,7 @@ class GraphDataModule(LightningDataModule):
pos=pos, pos=pos,
edge_attr=edge_attr, edge_attr=edge_attr,
y=temperature.unsqueeze(-1), y=temperature.unsqueeze(-1),
boundary_temperature=boundary_vales[bottom_boundary_ids].max(),
) )
def setup(self, stage: str = None): def setup(self, stage: str = None):

View File

@@ -50,7 +50,10 @@ class GraphSolver(LightningModule):
for _ in range(self.unrolling_steps): for _ in range(self.unrolling_steps):
x_prev = x.detach() x_prev = x.detach()
x = self(x_prev, c, edge_index=edge_index, edge_attr=edge_attr) x = self(x_prev, c, edge_index=edge_index, edge_attr=edge_attr)
loss += self.loss(x, y) actual_loss = self.loss(x, y)
loss += actual_loss
print(f"Train step loss: {actual_loss.item()}")
self._log_loss(loss, batch, "train") self._log_loss(loss, batch, "train")
return loss return loss
@@ -78,3 +81,7 @@ class GraphSolver(LightningModule):
def configure_optimizers(self): def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer return optimizer
def scale_bc(self, data: Batch, y: torch.Tensor):
t = data.boundary_temperature[data.batch]
return y * t