This commit is contained in:
FilippoOlivo
2025-09-26 09:10:02 +02:00
parent f3be9e99f8
commit c6c416e682
2 changed files with 16 additions and 2 deletions

View File

@@ -40,6 +40,9 @@ class GraphDataModule(LightningDataModule):
pos = torch.tensor(self.geometry["points"][0], dtype=torch.float32)[
:, :2
]
bottom_boundary_ids = torch.tensor(
self.geometry["bottom_boundary_ids"][0], dtype=torch.int64
)
self.data = [
self._build_dataset(
torch.tensor(snapshot["conductivity"], dtype=torch.float32),
@@ -47,6 +50,7 @@ class GraphDataModule(LightningDataModule):
torch.tensor(snapshot["temperature"], dtype=torch.float32),
edge_index.T,
pos,
bottom_boundary_ids,
)
for snapshot in tqdm(hf_dataset, desc="Building graphs")
]
@@ -58,13 +62,15 @@ class GraphDataModule(LightningDataModule):
temperature: torch.Tensor,
edge_index: torch.Tensor,
pos: torch.Tensor,
bottom_boundary_ids: torch.Tensor,
) -> Data:
edge_index = to_undirected(edge_index, num_nodes=pos.size(0))
edge_attr = pos[edge_index[0]] - pos[edge_index[1]]
edge_attr = torch.cat(
[edge_attr, torch.norm(edge_attr, dim=1).unsqueeze(-1)], dim=1
)
boundary_temperature = boundary_vales[bottom_boundary_ids].max()
boundary_vales[bottom_boundary_ids] = 1.0
return Data(
x=boundary_vales.unsqueeze(-1),
c=conductivity.unsqueeze(-1),
@@ -72,6 +78,7 @@ class GraphDataModule(LightningDataModule):
pos=pos,
edge_attr=edge_attr,
y=temperature.unsqueeze(-1),
boundary_temperature=boundary_vales[bottom_boundary_ids].max(),
)
def setup(self, stage: str = None):

View File

@@ -50,7 +50,10 @@ class GraphSolver(LightningModule):
for _ in range(self.unrolling_steps):
x_prev = x.detach()
x = self(x_prev, c, edge_index=edge_index, edge_attr=edge_attr)
loss += self.loss(x, y)
actual_loss = self.loss(x, y)
loss += actual_loss
print(f"Train step loss: {actual_loss.item()}")
self._log_loss(loss, batch, "train")
return loss
@@ -78,3 +81,7 @@ class GraphSolver(LightningModule):
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer
def scale_bc(self, data: Batch, y: torch.Tensor):
t = data.boundary_temperature[data.batch]
return y * t