diff --git a/ThermalSolver/model/local_gno.py b/ThermalSolver/model/local_gno.py index 1c93401..cf6a977 100644 --- a/ThermalSolver/model/local_gno.py +++ b/ThermalSolver/model/local_gno.py @@ -180,7 +180,6 @@ class GatingGNO(nn.Module): x = blk(x, c, edge_index, edge_attr=edge_attr) if plot_results: x_ = self.dec(x) - assert bc == x[boundary_mask] plot_results_fn(x_, pos, i * _, batch=batch) return self.dec(x) diff --git a/ThermalSolver/model/point_net.py b/ThermalSolver/model/point_net.py index 700e286..f0725e1 100644 --- a/ThermalSolver/model/point_net.py +++ b/ThermalSolver/model/point_net.py @@ -108,14 +108,14 @@ class MLP(torch.nn.Module): tmp_layers.append(self._output_dim) self._layers = [] - self._batchnorm = [] + self._LayerNorm = [] for i in range(len(tmp_layers) - 1): self._layers.append( self.spect_norm(nn.Linear(tmp_layers[i], tmp_layers[i + 1])) ) - self._batchnorm.append(nn.LazyBatchNorm1d()) + self._LayerNorm.append(nn.LazyLayerNorm()) if isinstance(func, list): self._functions = func @@ -124,7 +124,7 @@ class MLP(torch.nn.Module): unique_list = [] for layer, func, bnorm in zip( - self._layers[:-1], self._functions, self._batchnorm + self._layers[:-1], self._functions, self._LayerNorm ): unique_list.append(layer) @@ -208,7 +208,7 @@ class TNet(nn.Module): ) self._function = function() - self._bn1 = nn.LazyBatchNorm1d() + self._bn1 = nn.LazyLayerNorm() def forward(self, X): """Forward pass for T-Net @@ -299,9 +299,9 @@ class PointNet(nn.Module): self._tnet_feature = TNet(input_dim=64) self._function = function() - self._bn1 = nn.LazyBatchNorm1d() - self._bn2 = nn.LazyBatchNorm1d() - self._bn3 = nn.LazyBatchNorm1d() + self._bn1 = nn.LazyLayerNorm() + self._bn2 = nn.LazyLayerNorm() + self._bn3 = nn.LazyLayerNorm() def concat(self, embedding, input_): """Returns concatenation of global and local features for Point-Net @@ -370,205 +370,3 @@ class PointNet(nn.Module): X = self._mlp4(X) return X - - -class ConvTNet(nn.Module): - """T-Net base class. Implementation of T-Network with convolutional layers. - - Reference: Ali Kashefi et al. https://arxiv.org/abs/2208.13434 - """ - - def __init__(self, input_dim): - """T-Net block constructor - - :param input_dim: input dimension of point cloud - :type input_dim: int - """ - super().__init__() - - function = nn.Tanh - self._function = function() - - self._block1 = nn.Sequential( - nn.Conv1d(input_dim, 64, 1), - nn.BatchNorm1d(64), - self._function, - nn.Conv1d(64, 128, 1), - nn.BatchNorm1d(128), - self._function, - nn.Conv1d(128, 1024, 1), - nn.BatchNorm1d(1024), - self._function, - ) - - self._block2 = MLP( - input_dim=1024, - output_dim=input_dim * input_dim, - layers=[512, 256], - func=function, - batch_norm=True, - ) - - def forward(self, X): - """Forward pass for T-Net - - :param X: input tensor, shape [batch, $input_{dim}$, N] - with batch the batch size, N number of points and $input_{dim}$ - the input dimension of the point cloud. - :type X: torch.tensor - :return: output affine matrix transformation, shape - [batch, $input_{dim} \times input_{dim}$] with batch - the batch size and $input_{dim}$ the input dimension - of the point cloud. - :rtype: torch.tensor - """ - - batch, input_dim = X.shape[0], X.shape[1] - - # encoding using first MLP - X = self._block1(X) - - # applying symmetric function to aggregate information (using max as default) - X, _ = torch.max(X, dim=-1) - - # decoding using third MLP - X = self._block2(X) - - return X.reshape(batch, input_dim, input_dim) - - -class ConvPointNet(nn.Module): - """Point-Net base class. Implementation of Point Network for segmentation. - - Reference: Ali Kashefi et al. https://arxiv.org/abs/2208.13434 - """ - - def __init__(self, input_dim, output_dim, tnet=False): - """Point-Net block constructor - - :param input_dim: input dimension of point cloud - :type input_dim: int - :param output_dim: output dimension of point cloud - :type output_dim: int - :param tnet: apply T-Net transformation, defaults to False - :type tnet: bool, optional - """ - super().__init__() - - self._function = nn.Tanh() - self._use_tnet = tnet - - self._block1 = nn.Sequential( - nn.Conv1d(input_dim, 64, 1), - nn.BatchNorm1d(64), - self._function, - nn.Conv1d(64, 64, 1), - nn.BatchNorm1d(64), - self._function, - ) - - self._block2 = nn.Sequential( - nn.Conv1d(64, 64, 1), - nn.BatchNorm1d(64), - self._function, - nn.Conv1d(64, 128, 1), - nn.BatchNorm1d(128), - self._function, - nn.Conv1d(128, 1024, 1), - nn.BatchNorm1d(1024), - self._function, - ) - - self._block3 = nn.Sequential( - nn.Conv1d(1088, 512, 1), - nn.BatchNorm1d(512), - self._function, - nn.Conv1d(512, 256, 1), - nn.BatchNorm1d(256), - self._function, - nn.Conv1d(256, 128, 1), - nn.BatchNorm1d(128), - self._function, - ) - - self._block4 = nn.Conv1d(128, output_dim, 1) - - if self._use_tnet: - self._tnet_transform = ConvTNet(input_dim=input_dim) - self._tnet_feature = ConvTNet(input_dim=64) - - def concat(self, embedding, input_): - """ - Returns concatenation of global and local features for Point-Net - - :param embedding: global features of Point-Net, shape [batch, $input_{dim}$] - with batch the batch size and $input_{dim}$ the input dimension - of the point cloud. - :type embedding: torch.tensor - :param input_: local features of Point-Net, shape [batch, N, $input_{dim}$] - with batch the batch size, N number of points and $input_{dim}$ - the input dimension of the point cloud. - :type input_: torch.tensor - :return: concatenation vector, shape [batch, N, $input_{dim}$] - with batch the batch size, N number of points and $input_{dim}$ - :rtype: torch.tensor - """ - n_points = input_.shape[-1] - embedding = embedding.unsqueeze(2).repeat(1, 1, n_points) - return torch.cat([embedding, input_], dim=1) - - def forward(self, X): - """Forward pass for Point-Net - - :param X: input tensor, shape [batch, N, $input_{dim}$] - with batch the batch size, N number of points and $input_{dim}$ - the input dimension of the point cloud. - :type X: torch.tensor - :return: segmentation vector, shape [batch, N, $output_{dim}$] - with batch the batch size, N number of points and $output_{dim}$ - the output dimension of the point cloud. - :rtype: torch.tensor - """ - - # permuting indeces - X = X.permute(0, 2, 1) - - # using transform tnet if needed - if self._use_tnet: - transform = self._tnet_transform(X) - X = X.transpose(2, 1) - X = torch.matmul(X, transform) - X = X.transpose(2, 1) - - # encoding using first MLP - X = self._block1(X) - - # using transform tnet if needed - if self._use_tnet: - transform = self._tnet_feature(X) - X = X.transpose(2, 1) - X = torch.matmul(X, transform) - X = X.transpose(2, 1) - - # saving latent representation for later concatanation - latent = X - - # encoding using second MLP - X = self._block2(X) - - # applying symmetric function to aggregate information (using max as default) - X, _ = torch.max(X, dim=-1) - - # concatenating with latent vector - X = self.concat(X, latent) - - # decoding using third MLP - X = self._block3(X) - - # decoding using fourth MLP - X = self._block4(X) - - # permuting indeces - X = X.permute(0, 2, 1) - - return X