fix model
This commit is contained in:
@@ -7,6 +7,7 @@ from matplotlib.tri import Triangulation
|
||||
from .model.finite_difference import FiniteDifferenceStep
|
||||
import os
|
||||
|
||||
|
||||
def import_class(class_path: str):
|
||||
module_path, class_name = class_path.rsplit(".", 1) # split last dot
|
||||
module = importlib.import_module(module_path) # import the module
|
||||
@@ -14,7 +15,7 @@ def import_class(class_path: str):
|
||||
return cls
|
||||
|
||||
|
||||
def _plot_mesh(pos_, y_, y_pred_, y_true_ ,batch, i, batch_idx):
|
||||
def _plot_mesh(pos_, y_, y_pred_, y_true_, batch, i, batch_idx):
|
||||
for j in [0, 10, 20, 30]:
|
||||
idx = (batch == j).nonzero(as_tuple=True)[0]
|
||||
y = y_[idx].detach().cpu()
|
||||
@@ -49,6 +50,7 @@ def _plot_mesh(pos_, y_, y_pred_, y_true_ ,batch, i, batch_idx):
|
||||
plt.savefig(name, dpi=72)
|
||||
plt.close()
|
||||
|
||||
|
||||
def _plot_losses(losses, batch_idx):
|
||||
folder = f"{batch_idx:02d}_images"
|
||||
plt.figure()
|
||||
@@ -74,8 +76,8 @@ class GraphSolver(LightningModule):
|
||||
super().__init__()
|
||||
self.model = import_class(model_class_path)(**model_init_args)
|
||||
# for param in self.model.parameters():
|
||||
# print(f"Param: {param.shape}, Grad: {param.grad}")
|
||||
# print(f"Param: {param[0]}")
|
||||
# print(f"Param: {param.shape}, Grad: {param.grad}")
|
||||
# print(f"Param: {param[0]}")
|
||||
self.loss = loss if loss is not None else torch.nn.MSELoss()
|
||||
self.unrolling_steps = unrolling_steps
|
||||
|
||||
@@ -101,29 +103,36 @@ class GraphSolver(LightningModule):
|
||||
return (0.5 * (c[edge_index[0]] + c[edge_index[1]])).squeeze()
|
||||
|
||||
def _compute_model_steps(
|
||||
self, x, edge_index, edge_attr, boundary_mask, boundary_values
|
||||
):
|
||||
out = self.model(x, edge_index, edge_attr)
|
||||
self,
|
||||
x,
|
||||
edge_index,
|
||||
edge_attr,
|
||||
boundary_mask,
|
||||
boundary_values,
|
||||
conductivity,
|
||||
):
|
||||
out = self.model(x, edge_index, edge_attr, conductivity)
|
||||
out[boundary_mask] = boundary_values.unsqueeze(-1)
|
||||
# print(torch.min(out), torch.max(out))
|
||||
return out
|
||||
|
||||
def _preprocess_batch(self, batch: Batch):
|
||||
x, y, c, edge_index, edge_attr = (
|
||||
x, y, c, edge_index, edge_attr, nodal_area = (
|
||||
batch.x,
|
||||
batch.y,
|
||||
batch.c,
|
||||
batch.edge_index,
|
||||
batch.edge_attr,
|
||||
batch.nodal_area,
|
||||
)
|
||||
edge_attr = 1 / edge_attr
|
||||
c_ij = self._compute_c_ij(c, edge_index)
|
||||
edge_attr = edge_attr * c_ij
|
||||
# edge_attr = edge_attr / torch.max(edge_attr)
|
||||
return x, y, edge_index, edge_attr
|
||||
conductivity = self._compute_c_ij(c, edge_index)
|
||||
edge_attr = edge_attr * conductivity
|
||||
return x, y, edge_index, edge_attr, conductivity
|
||||
|
||||
def training_step(self, batch: Batch):
|
||||
x, y, edge_index, edge_attr = self._preprocess_batch(batch)
|
||||
x, y, edge_index, edge_attr, conductivity = self._preprocess_batch(
|
||||
batch
|
||||
)
|
||||
# deg = self._compute_deg(edge_index, edge_attr, x.size(0))
|
||||
losses = []
|
||||
# print(x.shape, y.shape)
|
||||
@@ -160,12 +169,13 @@ class GraphSolver(LightningModule):
|
||||
# deg,
|
||||
batch.boundary_mask,
|
||||
batch.boundary_values,
|
||||
conductivity,
|
||||
)
|
||||
x = out
|
||||
# print(out.shape, y[:, i, :].shape)
|
||||
losses.append(self.loss(out.flatten(), y[:, i, :].flatten()))
|
||||
# print(self.model.scale_edge_attr.item())
|
||||
|
||||
|
||||
loss = torch.stack(losses).mean()
|
||||
# for param in self.model.parameters():
|
||||
# print(f"Param: {param.shape}, Grad: {param.grad}")
|
||||
@@ -173,26 +183,40 @@ class GraphSolver(LightningModule):
|
||||
self._log_loss(loss, batch, "train")
|
||||
return loss
|
||||
|
||||
|
||||
def validation_step(self, batch: Batch, batch_idx):
|
||||
x, y, edge_index, edge_attr = self._preprocess_batch(batch)
|
||||
x, y, edge_index, edge_attr, conductivity = self._preprocess_batch(
|
||||
batch
|
||||
)
|
||||
# deg = self._compute_deg(edge_index, edge_attr, x.size(0))
|
||||
losses = []
|
||||
pos = batch.pos
|
||||
for i in range(self.unrolling_steps):
|
||||
out = self._compute_model_steps(
|
||||
# torch.cat([x,pos], dim=-1),
|
||||
x,
|
||||
edge_index,
|
||||
edge_attr,
|
||||
# deg,
|
||||
batch.boundary_mask,
|
||||
batch.boundary_values,
|
||||
# torch.cat([x,pos], dim=-1),
|
||||
x,
|
||||
edge_index,
|
||||
edge_attr,
|
||||
# deg,
|
||||
batch.boundary_mask,
|
||||
batch.boundary_values,
|
||||
conductivity,
|
||||
)
|
||||
if (batch_idx == 0 and self.current_epoch % 10 == 0 and self.current_epoch > 20):
|
||||
_plot_mesh(batch.pos, x, out, y[:, i, :], batch.batch, i, self.current_epoch)
|
||||
if (
|
||||
batch_idx == 0
|
||||
and self.current_epoch % 10 == 0
|
||||
and self.current_epoch > 0
|
||||
):
|
||||
_plot_mesh(
|
||||
batch.pos,
|
||||
x,
|
||||
out,
|
||||
y[:, i, :],
|
||||
batch.batch,
|
||||
i,
|
||||
self.current_epoch,
|
||||
)
|
||||
x = out
|
||||
losses.append(self.loss(out , y[:, i, :]))
|
||||
losses.append(self.loss(out, y[:, i, :]))
|
||||
|
||||
loss = torch.stack(losses).mean()
|
||||
self._log_loss(loss, batch, "val")
|
||||
@@ -202,5 +226,5 @@ class GraphSolver(LightningModule):
|
||||
pass
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.AdamW(self.parameters(), lr=5e-3)
|
||||
optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
|
||||
return optimizer
|
||||
|
||||
Reference in New Issue
Block a user