add final loss and change model

This commit is contained in:
Filippo Olivo
2025-11-13 16:18:54 +01:00
parent dc59114f4a
commit ea9cf7c57c
2 changed files with 26 additions and 31 deletions

View File

@@ -17,11 +17,11 @@ class FiniteDifferenceStep(MessagePassing):
spectral_norm(nn.Linear(hidden_dim // 2, hidden_dim)),
)
self.update_net = nn.Sequential(
spectral_norm(nn.Linear(2 * hidden_dim, hidden_dim)),
nn.GELU(),
spectral_norm(nn.Linear(hidden_dim, hidden_dim)),
)
# self.update_net = nn.Sequential(
# spectral_norm(nn.Linear(2 * hidden_dim, hidden_dim)),
# nn.GELU(),
# spectral_norm(nn.Linear(hidden_dim, hidden_dim)),
# )
self.out_net = nn.Sequential(
spectral_norm(nn.Linear(hidden_dim, hidden_dim // 2)),
@@ -47,8 +47,9 @@ class FiniteDifferenceStep(MessagePassing):
"""
TODO: add docstring.
"""
update_input = torch.cat([x, aggr_out], dim=-1)
return self.update_net(update_input)
# update_input = torch.cat([x, aggr_out], dim=-1)
# return self.update_net(update_input)
return aggr_out
def aggregate(self, inputs, index, deg):
"""