import importlib

import torch
from lightning import LightningModule
from matplotlib import pyplot as plt
from matplotlib.tri import Triangulation


def _plot_mesh(x, y, y_pred, path="point_net.png"):
    """Save true / predicted / absolute-error contour plots for one sample.

    Only the first element along dim 0 of each tensor is plotted. Rows of that
    sample whose first feature (``x[..., 0]``) equals -1 are dropped before
    plotting (presumably padding points — confirm against the dataset).
    Columns 2 and 3 of ``x`` are used as the triangulation coordinates.

    Args:
        x: Batched point features; sample 0 is indexed as ``x[0][:, col]``.
        y: Batched ground-truth values aligned with the points of ``x``.
        y_pred: Batched predictions aligned with the points of ``x``.
        path: Output image file. It is overwritten on every call.
    """
    x = x[0, ...].detach().cpu()
    # Compute the keep-mask once and reuse it for every per-point tensor.
    valid = x[:, 0] != -1
    pos = x[valid]
    y = y[0, ...].detach().cpu()[valid]
    y_pred = y_pred[0, ...].detach().cpu()[valid]

    tria = Triangulation(pos[:, 2], pos[:, 3])
    panels = (
        (y, "True temperature"),
        (y_pred, "Predicted temperature"),
        (torch.abs(y_pred - y), "Error"),
    )

    fig = plt.figure(figsize=(18, 5))
    try:
        for idx, (values, title) in enumerate(panels, start=1):
            plt.subplot(1, 3, idx)
            plt.tricontourf(tria, values.squeeze().numpy(), levels=14)
            plt.colorbar()
            plt.title(title)
        fig.suptitle("PointNet", fontsize=16)
        fig.savefig(path, dpi=300)
    finally:
        # Close the figure: this runs once per test batch, and leaving figures
        # open accumulates memory in pyplot's global figure manager.
        plt.close(fig)


def import_class(class_path: str):
    """Import and return the attribute named by a dotted path.

    Args:
        class_path: Path of the form ``"package.module.ClassName"``.

    Returns:
        The object ``ClassName`` looked up on the imported module.

    Raises:
        ValueError: If ``class_path`` is not of the form ``module.Name``.
        ImportError: If the module cannot be imported.
        AttributeError: If the module has no such attribute.
    """
    module_path, sep, class_name = class_path.rpartition(".")
    if not sep or not module_path or not class_name:
        raise ValueError(
            f"Expected a dotted path 'module.ClassName', got {class_path!r}"
        )
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


class PointSolver(LightningModule):
    """Lightning module wrapping a model that is imported from a dotted path.

    Args:
        model_class_path: Dotted path of the model class to instantiate.
        model_init_args: Keyword arguments forwarded to the model constructor.
        loss: Loss module, called as ``loss(y_pred, y)``. Defaults to
            ``torch.nn.MSELoss()``.
        lr: Learning rate for the Adam optimizer (default ``1e-3``).
    """

    def __init__(
        self,
        model_class_path: str,
        model_init_args: dict,
        loss: torch.nn.Module = None,
        lr: float = 1e-3,
    ):
        super().__init__()
        self.model = import_class(model_class_path)(**model_init_args)
        self.loss = loss if loss is not None else torch.nn.MSELoss()
        self.lr = lr

    def forward(
        self,
        x: torch.Tensor,
    ):
        """Run the wrapped model on ``x``."""
        return self.model(x)

    def _compute_loss(self, x, y):
        """Return ``self.loss(x, y)`` where ``x`` is the prediction."""
        return self.loss(x, y)

    def _log_loss(self, loss, batch, stage: str):
        """Log ``loss`` under ``"{stage}/loss"`` as an epoch-level metric.

        ``batch`` is the ``(x, y)`` pair. Its length is always 2, so the real
        batch size (needed for correct epoch averaging) is taken from ``x``.
        """
        self.log(
            f"{stage}/loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            batch_size=len(batch[0]),
        )
        return loss

    def _shared_step(self, batch, stage: str):
        """Predict, compute and log the loss; return ``(x, y, y_pred, loss)``."""
        x, y = batch
        y_pred = self(x)
        loss = self._compute_loss(y_pred, y)
        self._log_loss(loss, batch, stage)
        return x, y, y_pred, loss

    def training_step(self, batch, _):
        return self._shared_step(batch, "train")[-1]

    def validation_step(self, batch, _):
        return self._shared_step(batch, "val")[-1]

    def test_step(self, batch, _):
        x, y, y_pred, loss = self._shared_step(batch, "test")
        # NOTE(review): this plots every test batch and overwrites the same
        # file, so only the last batch's plot survives.
        _plot_mesh(x, y, y_pred)
        return loss

    def configure_optimizers(self):
        """Use Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)