@@ -19,14 +19,16 @@ def grad(output_, input_, components=None, d=None):
         raise RuntimeError

     output_fieldname = output_.labels[0]
     gradients = torch.autograd.grad(
         output_,
         input_,
-        grad_outputs=torch.ones(output_.size()).to(
-            dtype=input_.dtype,
-            device=input_.device),
-        create_graph=True, retain_graph=True, allow_unused=True)[0]
+        grad_outputs=torch.ones(output_.size(), dtype=output_.dtype,
+                                device=output_.device),
+        create_graph=True,
+        retain_graph=True,
+        allow_unused=True
+    )[0]

     gradients.labels = input_.labels
     gradients = gradients.extract(d)
     gradients.labels = [f'd{output_fieldname}d{i}' for i in d]
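The change above builds the autograd seed tensor with the correct dtype and device up front, instead of allocating on CPU and converting with `.to(...)`; the seed must match `output_`, not `input_`. A minimal sketch in plain PyTorch (no `LabelTensor`; the toy field `u = x² + y²` is an assumption for illustration) of the same `torch.autograd.grad` call:

```python
import torch

x = torch.rand(8, 2, requires_grad=True)   # batch of 2-D input points
u = (x ** 2).sum(dim=1, keepdim=True)      # scalar field u(x, y) = x^2 + y^2

# Seed created directly with the output's dtype/device: no extra copy.
seed = torch.ones(u.size(), dtype=u.dtype, device=u.device)
du = torch.autograd.grad(u, x, grad_outputs=seed,
                         create_graph=True, retain_graph=True,
                         allow_unused=True)[0]
print(du)  # each row is [du/dx, du/dy] = [2x, 2y]
```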
@@ -83,19 +85,16 @@ def div(output_, input_, components=None, d=None):
         raise ValueError

     grad_output = grad(output_, input_, components, d)
-    div = torch.zeros(input_.shape[0], 1)
-    # print(grad_output)
-    # print('empty', div)
+    div = torch.zeros(input_.shape[0], 1, device=output_.device)
     labels = [None] * len(components)
     for i, (c, d) in enumerate(zip(components, d)):
         c_fields = f'd{c}d{d}'
-        # print(c_fields)
         div[:, 0] += grad_output.extract(c_fields).sum(axis=1)
         labels[i] = c_fields
-    # print('full', div)
-    # print(labels)

-    return LabelTensor(div, ['+'.join(labels)])
+    div = div.as_subclass(LabelTensor)
+    div.labels = ['+'.join(labels)]
+    return div


 def nabla(output_, input_, components=None, d=None, method='std'):
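Besides dropping the debug prints and moving the buffer onto `output_.device`, this hunk swaps the `LabelTensor(div, ...)` constructor for `Tensor.as_subclass`, which re-views the existing buffer as a `LabelTensor` without copying it, so it stays on whatever device it was allocated on. A minimal sketch with a toy `Tensor` subclass standing in for `LabelTensor`:

```python
import torch

class Labeled(torch.Tensor):
    """Toy stand-in for PINA's LabelTensor: a Tensor plus a labels attribute."""
    labels = None

t = torch.zeros(4, 1)
lt = t.as_subclass(Labeled)           # same storage, new Python type
lt.labels = ['dudx+dudy']
assert lt.data_ptr() == t.data_ptr()  # no copy was made
```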
@@ -120,14 +119,15 @@ def nabla(output_, input_, components=None, d=None, method='std'):

     if len(components) == 1:
         grad_output = grad(output_, input_, components=components, d=d)
-        result = torch.zeros(output_.shape[0], 1)
+        result = torch.zeros(output_.shape[0], 1, device=output_.device)
         for i, label in enumerate(grad_output.labels):
             gg = grad(grad_output, input_, d=d, components=[label])
             result[:, 0] += gg[:, i]
         labels = [f'dd{components[0]}']

     else:
-        result = torch.empty(input_.shape[0], len(components))
+        result = torch.empty(input_.shape[0], len(components),
+                             device=output_.device)
         labels = [None] * len(components)
         for idx, (ci, di) in enumerate(zip(components, d)):

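Both accumulation buffers in `nabla` are now allocated on `output_.device`; with the old code, accumulating CUDA gradients into a CPU buffer raises a device-mismatch error. A small illustration of the failure mode (guarded so it also runs on CPU-only machines):

```python
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
gg = torch.rand(8, 2, device=device)        # stand-in for a gradient batch

result = torch.zeros(8, 1, device=device)   # allocated where gg lives
result[:, 0] += gg[:, 0]                    # fine: same device

if device == 'cuda':
    bad = torch.zeros(8, 1)                 # CPU buffer, as in the old code
    try:
        bad[:, 0] += gg[:, 0]               # cuda -> cpu accumulation
    except RuntimeError as e:
        print('device mismatch:', e)
```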
@@ -140,28 +140,20 @@ def nabla(output_, input_, components=None, d=None, method='std'):
             result[:, idx] = grad(grad_output, input_, d=di).flatten()
             labels[idx] = f'dd{ci}dd{di}'

-    return LabelTensor(result, labels)
+    result = result.as_subclass(LabelTensor)
+    result.labels = labels
+    return result


-def advection(output_, input_):
-    """
-    TODO
-    """
-    dimension = len(output_.labels)
-    for i, label in enumerate(output_.labels):
-        # compute u dot gradient in each direction
-        gradient_loc = grad(output_.extract([label]),
-                            input_).extract(input_.labels[:dimension])
-        dim_0 = gradient_loc.shape[0]
-        dim_1 = gradient_loc.shape[1]
-        u_dot_grad_loc = torch.bmm(output_.view(dim_0, 1, dim_1),
-                                   gradient_loc.view(dim_0, dim_1, 1))
-        u_dot_grad_loc = LabelTensor(torch.reshape(u_dot_grad_loc,
-                                                   (u_dot_grad_loc.shape[0],
-                                                    u_dot_grad_loc.shape[1])),
-                                     [input_.labels[i]])
-        if i == 0:
-            adv_term = u_dot_grad_loc
-        else:
-            adv_term = adv_term.append(u_dot_grad_loc)
-    return adv_term
+def advection(output_, input_, velocity_field, components=None, d=None):
+    if d is None:
+        d = input_.labels
+
+    if components is None:
+        components = output_.labels
+
+    tmp = grad(output_, input_, components, d
+               ).reshape(-1, len(components), len(d)).transpose(0, 1)
+
+    tmp *= output_.extract(velocity_field)
+    return tmp.sum(dim=2).T
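The rewritten `advection` replaces the per-component `torch.bmm` loop with a single batched computation over the full Jacobian returned by `grad`. A minimal sketch of the tensor algebra with plain tensors and hypothetical shapes (`jac` stands in for `grad(output_, input_, components, d)`, `vel` for `output_.extract(velocity_field)`):

```python
import torch

batch, n_comp, n_d = 8, 2, 2
jac = torch.rand(batch, n_comp * n_d)   # flattened Jacobian from grad()
vel = torch.rand(batch, n_d)            # velocity field per sample

# (batch, n_comp * n_d) -> (n_comp, batch, n_d): one gradient row per component.
tmp = jac.reshape(-1, n_comp, n_d).transpose(0, 1)

# Scale each gradient by the velocity (broadcast over components),
# then contract over the spatial directions: (u . grad)u per component.
adv = (tmp * vel).sum(dim=2).T
print(adv.shape)  # torch.Size([8, 2]) -> (batch, n_comp)
```

The broadcast works because `vel` of shape `(batch, n_d)` expands against `tmp` of shape `(n_comp, batch, n_d)`; summing over `dim=2` and transposing recovers the `(batch, n_comp)` layout the old loop built one column at a time.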