minor
This commit is contained in:
@@ -40,6 +40,9 @@ class GraphDataModule(LightningDataModule):
|
|||||||
pos = torch.tensor(self.geometry["points"][0], dtype=torch.float32)[
|
pos = torch.tensor(self.geometry["points"][0], dtype=torch.float32)[
|
||||||
:, :2
|
:, :2
|
||||||
]
|
]
|
||||||
|
bottom_boundary_ids = torch.tensor(
|
||||||
|
self.geometry["bottom_boundary_ids"][0], dtype=torch.int64
|
||||||
|
)
|
||||||
self.data = [
|
self.data = [
|
||||||
self._build_dataset(
|
self._build_dataset(
|
||||||
torch.tensor(snapshot["conductivity"], dtype=torch.float32),
|
torch.tensor(snapshot["conductivity"], dtype=torch.float32),
|
||||||
@@ -47,6 +50,7 @@ class GraphDataModule(LightningDataModule):
|
|||||||
torch.tensor(snapshot["temperature"], dtype=torch.float32),
|
torch.tensor(snapshot["temperature"], dtype=torch.float32),
|
||||||
edge_index.T,
|
edge_index.T,
|
||||||
pos,
|
pos,
|
||||||
|
bottom_boundary_ids,
|
||||||
)
|
)
|
||||||
for snapshot in tqdm(hf_dataset, desc="Building graphs")
|
for snapshot in tqdm(hf_dataset, desc="Building graphs")
|
||||||
]
|
]
|
||||||
@@ -58,13 +62,15 @@ class GraphDataModule(LightningDataModule):
|
|||||||
temperature: torch.Tensor,
|
temperature: torch.Tensor,
|
||||||
edge_index: torch.Tensor,
|
edge_index: torch.Tensor,
|
||||||
pos: torch.Tensor,
|
pos: torch.Tensor,
|
||||||
|
bottom_boundary_ids: torch.Tensor,
|
||||||
) -> Data:
|
) -> Data:
|
||||||
edge_index = to_undirected(edge_index, num_nodes=pos.size(0))
|
edge_index = to_undirected(edge_index, num_nodes=pos.size(0))
|
||||||
edge_attr = pos[edge_index[0]] - pos[edge_index[1]]
|
edge_attr = pos[edge_index[0]] - pos[edge_index[1]]
|
||||||
edge_attr = torch.cat(
|
edge_attr = torch.cat(
|
||||||
[edge_attr, torch.norm(edge_attr, dim=1).unsqueeze(-1)], dim=1
|
[edge_attr, torch.norm(edge_attr, dim=1).unsqueeze(-1)], dim=1
|
||||||
)
|
)
|
||||||
|
boundary_temperature = boundary_vales[bottom_boundary_ids].max()
|
||||||
|
boundary_vales[bottom_boundary_ids] = 1.0
|
||||||
return Data(
|
return Data(
|
||||||
x=boundary_vales.unsqueeze(-1),
|
x=boundary_vales.unsqueeze(-1),
|
||||||
c=conductivity.unsqueeze(-1),
|
c=conductivity.unsqueeze(-1),
|
||||||
@@ -72,6 +78,7 @@ class GraphDataModule(LightningDataModule):
|
|||||||
pos=pos,
|
pos=pos,
|
||||||
edge_attr=edge_attr,
|
edge_attr=edge_attr,
|
||||||
y=temperature.unsqueeze(-1),
|
y=temperature.unsqueeze(-1),
|
||||||
|
boundary_temperature=boundary_vales[bottom_boundary_ids].max(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def setup(self, stage: str = None):
|
def setup(self, stage: str = None):
|
||||||
|
|||||||
@@ -50,7 +50,10 @@ class GraphSolver(LightningModule):
|
|||||||
for _ in range(self.unrolling_steps):
|
for _ in range(self.unrolling_steps):
|
||||||
x_prev = x.detach()
|
x_prev = x.detach()
|
||||||
x = self(x_prev, c, edge_index=edge_index, edge_attr=edge_attr)
|
x = self(x_prev, c, edge_index=edge_index, edge_attr=edge_attr)
|
||||||
loss += self.loss(x, y)
|
actual_loss = self.loss(x, y)
|
||||||
|
loss += actual_loss
|
||||||
|
print(f"Train step loss: {actual_loss.item()}")
|
||||||
|
|
||||||
self._log_loss(loss, batch, "train")
|
self._log_loss(loss, batch, "train")
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
@@ -78,3 +81,7 @@ class GraphSolver(LightningModule):
|
|||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
|
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
|
def scale_bc(self, data: Batch, y: torch.Tensor):
|
||||||
|
t = data.boundary_temperature[data.batch]
|
||||||
|
return y * t
|
||||||
|
|||||||
Reference in New Issue
Block a user