Formatting
* Adding black as dev dependency * Formatting pina code * Formatting tests
This commit is contained in:
committed by
Nicola Demo
parent
4c4482b155
commit
42ab1a666b
@@ -5,6 +5,7 @@ All operator take as input a tensor onto which computing the operator, a tensor
|
||||
to which computing the operator, the name of the output variables to calculate the operator
|
||||
for (in case of multidimensional functions), and the variables name on which the operator is calculated.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from pina.label_tensor import LabelTensor
|
||||
|
||||
@@ -56,9 +57,9 @@ def grad(output_, input_, components=None, d=None):
|
||||
gradients = torch.autograd.grad(
|
||||
output_,
|
||||
input_,
|
||||
grad_outputs=torch.ones(output_.size(),
|
||||
dtype=output_.dtype,
|
||||
device=output_.device),
|
||||
grad_outputs=torch.ones(
|
||||
output_.size(), dtype=output_.dtype, device=output_.device
|
||||
),
|
||||
create_graph=True,
|
||||
retain_graph=True,
|
||||
allow_unused=True,
|
||||
@@ -83,8 +84,9 @@ def grad(output_, input_, components=None, d=None):
|
||||
raise RuntimeError
|
||||
gradients = grad_scalar_output(output_, input_, d)
|
||||
|
||||
elif output_.shape[output_.ndim -
|
||||
1] >= 2: # vector output ##############################
|
||||
elif (
|
||||
output_.shape[output_.ndim - 1] >= 2
|
||||
): # vector output ##############################
|
||||
tensor_to_cat = []
|
||||
for i, c in enumerate(components):
|
||||
c_output = output_.extract([c])
|
||||
@@ -253,8 +255,11 @@ def advection(output_, input_, velocity_field, components=None, d=None):
|
||||
if components is None:
|
||||
components = output_.labels
|
||||
|
||||
tmp = (grad(output_, input_, components, d).reshape(-1, len(components),
|
||||
len(d)).transpose(0, 1))
|
||||
tmp = (
|
||||
grad(output_, input_, components, d)
|
||||
.reshape(-1, len(components), len(d))
|
||||
.transpose(0, 1)
|
||||
)
|
||||
|
||||
tmp *= output_.extract(velocity_field)
|
||||
return tmp.sum(dim=2).T
|
||||
|
||||
Reference in New Issue
Block a user