diff --git a/ThermalSolver/model/learnable_finite_difference.py b/ThermalSolver/model/learnable_finite_difference.py index 3e69af9..94efba1 100644 --- a/ThermalSolver/model/learnable_finite_difference.py +++ b/ThermalSolver/model/learnable_finite_difference.py @@ -9,7 +9,8 @@ class FiniteDifferenceStep(MessagePassing): TODO: add docstring. """ - def __init__(self, edge_ch=5, hidden_dim=16, aggr: str = "add"): + def __init__(self, hidden_dim=16, aggr: str = "add"): + print(aggr) super().__init__(aggr=aggr) self.x_embedding = nn.Sequential( spectral_norm(nn.Linear(1, hidden_dim // 2)), @@ -17,12 +18,6 @@ class FiniteDifferenceStep(MessagePassing): spectral_norm(nn.Linear(hidden_dim // 2, hidden_dim)), ) - # self.update_net = nn.Sequential( - # spectral_norm(nn.Linear(2 * hidden_dim, hidden_dim)), - # nn.GELU(), - # spectral_norm(nn.Linear(hidden_dim, hidden_dim)), - # ) - self.out_net = nn.Sequential( spectral_norm(nn.Linear(hidden_dim, hidden_dim // 2)), nn.GELU(), @@ -35,20 +30,18 @@ class FiniteDifferenceStep(MessagePassing): """ x_ = self.x_embedding(x) out = self.propagate(edge_index, x=x_, edge_attr=edge_attr, deg=deg) - return self.out_net(x_ + out) + return self.out_net(out) - def message(self, x_i, x_j, edge_attr): + def message(self, x_j, edge_attr): """ TODO: add docstring. """ - return (x_j - x_i) * edge_attr.view(-1, 1) + return x_j * edge_attr.view(-1, 1) - def update(self, aggr_out, x): + def update(self, aggr_out, _): """ TODO: add docstring. """ - # update_input = torch.cat([x, aggr_out], dim=-1) - # return self.update_net(update_input) return aggr_out def aggregate(self, inputs, index, deg):