Gradient accumulation in BPTT (#2)

commit a2dd348423 (parent 195c66b444), committed by GitHub, 2025-11-11 20:14:28 +01:00
4 changed files with 292 additions and 179 deletions
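The hunks shown below only touch `FiniteDifferenceStep`; the gradient-accumulation change named in the title is in one of the other changed files and is not visible here. As a hedged, generic sketch (not this repository's training loop; `model`, `seq`, `loss_fn`, and `chunk_len` are hypothetical), accumulating gradients across truncated-BPTT windows in PyTorch typically looks like this:

```python
import torch

def train_sequence(model, optimizer, seq, loss_fn, chunk_len=32):
    """Accumulate gradients across truncated-BPTT windows, then step once."""
    optimizer.zero_grad()
    hidden = None
    total_loss = 0.0
    for start in range(0, seq.size(0), chunk_len):
        chunk = seq[start:start + chunk_len]
        out, hidden = model(chunk, hidden)
        loss = loss_fn(out, chunk)
        loss.backward()              # gradients accumulate into .grad across chunks
        hidden = hidden.detach()     # truncate the graph at the window boundary
        total_loss += loss.item()
    optimizer.step()                 # one update using the accumulated gradients
    return total_loss
```

Each chunk's `backward()` adds into `.grad`, the hidden state is detached so the graph is truncated at the window boundary, and a single `optimizer.step()` applies the accumulated gradient.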


@@ -14,7 +14,7 @@ class FiniteDifferenceStep(MessagePassing):
aggr == "add"
), "Per somme pesate, l'aggregazione deve essere 'add'."
# self.root_weight = float(root_weight)
self.p = torch.nn.Parameter(torch.tensor(0.8))
self.p = torch.nn.Parameter(torch.tensor(1.0))
self.a = root_weight
def forward(self, x, edge_index, edge_attr, deg):
@@ -43,9 +43,7 @@ class FiniteDifferenceStep(MessagePassing):
"""
TODO: add docstring.
"""
a = torch.clamp(self.a, 0.0, 1.0)
return a * aggr_out + (1 - a) * x
# return self.a * aggr_out + (1 - self.a) * x
return aggr_out
class GraphFiniteDifference(nn.Module):
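For context on the second hunk: before this commit, `update()` returned a clamped convex mix of the aggregated messages and the root-node features; afterwards it returns the aggregate directly. A minimal, hypothetical sketch of such a step with PyTorch Geometric (the edge-weight handling in `message()` is an assumption, not taken from the repository):

```python
import torch
from torch_geometric.nn import MessagePassing

class FiniteDifferenceStepSketch(MessagePassing):
    """Hypothetical sketch, not the repository's class."""

    def __init__(self, root_weight: float = 0.5):
        super().__init__(aggr="add")  # weighted sums require 'add' aggregation
        self.p = torch.nn.Parameter(torch.tensor(1.0))
        self.a = root_weight

    def forward(self, x, edge_index, edge_attr):
        # x: [N, F] node features, edge_attr: [E] per-edge scalar weights (assumed)
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, edge_attr):
        return self.p * edge_attr.view(-1, 1) * x_j

    def update(self, aggr_out, x):
        # Pre-commit behaviour (removed in the hunk above):
        #   a = torch.clamp(self.a, 0.0, 1.0)
        #   return a * aggr_out + (1 - a) * x
        # Post-commit behaviour: return the aggregated messages directly.
        return aggr_out
```

With `aggr="add"`, `aggr_out[i]` is the weighted sum of neighbour messages, so dropping the `(1 - a) * x` term removes the explicit root/self contribution from the node update.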