minor
This commit is contained in:
@@ -40,6 +40,9 @@ class GraphDataModule(LightningDataModule):
|
||||
pos = torch.tensor(self.geometry["points"][0], dtype=torch.float32)[
|
||||
:, :2
|
||||
]
|
||||
bottom_boundary_ids = torch.tensor(
|
||||
self.geometry["bottom_boundary_ids"][0], dtype=torch.int64
|
||||
)
|
||||
self.data = [
|
||||
self._build_dataset(
|
||||
torch.tensor(snapshot["conductivity"], dtype=torch.float32),
|
||||
@@ -47,6 +50,7 @@ class GraphDataModule(LightningDataModule):
|
||||
torch.tensor(snapshot["temperature"], dtype=torch.float32),
|
||||
edge_index.T,
|
||||
pos,
|
||||
bottom_boundary_ids,
|
||||
)
|
||||
for snapshot in tqdm(hf_dataset, desc="Building graphs")
|
||||
]
|
||||
@@ -58,13 +62,15 @@ class GraphDataModule(LightningDataModule):
|
||||
temperature: torch.Tensor,
|
||||
edge_index: torch.Tensor,
|
||||
pos: torch.Tensor,
|
||||
bottom_boundary_ids: torch.Tensor,
|
||||
) -> Data:
|
||||
edge_index = to_undirected(edge_index, num_nodes=pos.size(0))
|
||||
edge_attr = pos[edge_index[0]] - pos[edge_index[1]]
|
||||
edge_attr = torch.cat(
|
||||
[edge_attr, torch.norm(edge_attr, dim=1).unsqueeze(-1)], dim=1
|
||||
)
|
||||
|
||||
boundary_temperature = boundary_vales[bottom_boundary_ids].max()
|
||||
boundary_vales[bottom_boundary_ids] = 1.0
|
||||
return Data(
|
||||
x=boundary_vales.unsqueeze(-1),
|
||||
c=conductivity.unsqueeze(-1),
|
||||
@@ -72,6 +78,7 @@ class GraphDataModule(LightningDataModule):
|
||||
pos=pos,
|
||||
edge_attr=edge_attr,
|
||||
y=temperature.unsqueeze(-1),
|
||||
boundary_temperature=boundary_vales[bottom_boundary_ids].max(),
|
||||
)
|
||||
|
||||
def setup(self, stage: str = None):
|
||||
|
||||
@@ -50,7 +50,10 @@ class GraphSolver(LightningModule):
|
||||
for _ in range(self.unrolling_steps):
|
||||
x_prev = x.detach()
|
||||
x = self(x_prev, c, edge_index=edge_index, edge_attr=edge_attr)
|
||||
loss += self.loss(x, y)
|
||||
actual_loss = self.loss(x, y)
|
||||
loss += actual_loss
|
||||
print(f"Train step loss: {actual_loss.item()}")
|
||||
|
||||
self._log_loss(loss, batch, "train")
|
||||
return loss
|
||||
|
||||
@@ -78,3 +81,7 @@ class GraphSolver(LightningModule):
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
|
||||
return optimizer
|
||||
|
||||
def scale_bc(self, data: Batch, y: torch.Tensor):
|
||||
t = data.boundary_temperature[data.batch]
|
||||
return y * t
|
||||
|
||||
Reference in New Issue
Block a user