equation class, fix minor bugs, diff domain (#89)
* equation class * difference domain * dummy dataloader * writer class * refactoring and minor fix
This commit is contained in:
@@ -188,7 +188,7 @@ def nabla(output_, input_, components=None, d=None, method='std'):
|
||||
result = torch.zeros(output_.shape[0], 1, device=output_.device)
|
||||
for i, label in enumerate(grad_output.labels):
|
||||
gg = grad(grad_output, input_, d=d, components=[label])
|
||||
result[:, 0] += gg[:, i]
|
||||
result[:, 0] += super(torch.Tensor, gg.T).__getitem__(i) # TODO improve
|
||||
labels = [f'dd{components[0]}']
|
||||
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user