PINA/pina/operators.py
"""Module for operators vectorize implementation"""
import torch
from pina.label_tensor import LabelTensor
def grad(output_, input_, components=None, d=None):
    """
    Compute the gradient of ``output_`` with respect to ``input_``.

    :param LabelTensor output_: the output tensor to differentiate.
    :param LabelTensor input_: the input tensor with respect to which the
        gradient is computed.
    :param list(str) components: the labels of the output components to
        differentiate. Default is all the components of ``output_``.
    :param list(str) d: the labels of the input variables on which the
        gradient is computed. Default is all the variables of ``input_``.
    :return: the gradient, with labels ``'d<component>d<variable>'``.
    :rtype: LabelTensor
    """

    def grad_scalar_output(output_, input_, d):
        """
        Compute the gradient of the scalar-valued ``output_`` with respect
        to the ``d`` variables of ``input_``.
        """

        if len(output_.labels) != 1:
            raise RuntimeError('only scalar fields can be passed to '
                               'grad_scalar_output')
        if not all([di in input_.labels for di in d]):
            raise RuntimeError('derivative variables not in input labels')

        output_fieldname = output_.labels[0]
        gradients = torch.autograd.grad(
            output_,
            input_,
            grad_outputs=torch.ones(output_.size()).to(
                dtype=input_.dtype,
                device=input_.device),
            create_graph=True, retain_graph=True, allow_unused=True)[0]
        gradients.labels = input_.labels
        gradients = gradients.extract(d)
        gradients.labels = [f'd{output_fieldname}d{i}' for i in d]

        return gradients

    if not isinstance(input_, LabelTensor):
        raise TypeError('input_ must be a LabelTensor')

    if d is None:
        d = input_.labels

    if components is None:
        components = output_.labels

    if output_.shape[1] == 1:  # scalar output ################################
        if components != output_.labels:
            raise RuntimeError('components must match the labels of a '
                               'scalar output')
        gradients = grad_scalar_output(output_, input_, d)
    elif output_.shape[1] >= 2:  # vector output ##############################
        # differentiate each requested component and stack the results
        for i, c in enumerate(components):
            c_output = output_.extract([c])
            if i == 0:
                gradients = grad_scalar_output(c_output, input_, d)
            else:
                gradients = gradients.append(
                    grad_scalar_output(c_output, input_, d))
    else:
        raise NotImplementedError

    return gradients
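
# Example usage of grad (a minimal, illustrative sketch: the sizes and the
# labels 'x', 'y', 'u' are assumptions, not fixed by this module):
#
#   pts = LabelTensor(torch.rand((10, 2), requires_grad=True), ['x', 'y'])
#   u = LabelTensor(torch.sum(pts**2, dim=1, keepdim=True), ['u'])
#   grad_u = grad(u, pts)             # labels: ['dudx', 'dudy']
#   grad_u_x = grad(u, pts, d=['x'])  # only du/dx
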
def div(output_, input_, components=None, d=None):
    """
    Compute the divergence of ``output_`` with respect to ``input_``.

    :param LabelTensor output_: the (vector-valued) output tensor.
    :param LabelTensor input_: the input tensor with respect to which the
        divergence is computed.
    :param list(str) components: the labels of the output components to use.
        Default is all the components of ``output_``.
    :param list(str) d: the labels of the input variables on which the
        divergence is computed. Default is all the variables of ``input_``.
    :return: the divergence, as a one-column LabelTensor.
    :rtype: LabelTensor
    """
    if not isinstance(input_, LabelTensor):
        raise TypeError('input_ must be a LabelTensor')

    if d is None:
        d = input_.labels

    if components is None:
        components = output_.labels

    if output_.shape[1] < 2 or len(components) < 2:
        raise ValueError('div supported only for vector fields')

    if len(components) != len(d):
        raise ValueError('components and d must have the same length')

    grad_output = grad(output_, input_, components, d)
    div = torch.zeros(input_.shape[0], 1)
    labels = [None] * len(components)
    # sum the diagonal entries d<c_i>d<d_i> of the Jacobian
    for i, (c, di) in enumerate(zip(components, d)):
        c_fields = f'd{c}d{di}'
        div[:, 0] += grad_output.extract(c_fields).sum(axis=1)
        labels[i] = c_fields

    return LabelTensor(div, ['+'.join(labels)])
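
# Example usage of div (a minimal, illustrative sketch with assumed labels):
#
#   pts = LabelTensor(torch.rand((10, 2), requires_grad=True), ['x', 'y'])
#   field = LabelTensor(torch.cat((pts.extract(['x'])**2,
#                                  pts.extract(['y'])**2), dim=1),
#                       ['ux', 'uy'])
#   div_field = div(field, pts)  # duxdx + duydy, labels: ['duxdx+duydy']
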
def nabla(output_, input_, components=None, d=None, method='std'):
    """
    Compute the Laplacian of ``output_`` with respect to ``input_``.

    :param LabelTensor output_: the output tensor.
    :param LabelTensor input_: the input tensor with respect to which the
        Laplacian is computed.
    :param list(str) components: the labels of the output components to use.
        Default is all the components of ``output_``.
    :param list(str) d: the labels of the input variables on which the
        Laplacian is computed. Default is all the variables of ``input_``.
    :param str method: the computational strategy, either 'std' (sum of the
        second derivatives) or 'divgrad' (divergence of the gradient, not
        yet implemented). Default is 'std'.
    :return: the Laplacian as a LabelTensor.
    :rtype: LabelTensor
    """
    if d is None:
        d = input_.labels

    if components is None:
        components = output_.labels

    if len(components) != len(d) and len(components) != 1:
        raise ValueError('components and d must have the same length, '
                         'unless a single component is passed')

    if method == 'divgrad':
        raise NotImplementedError
        # TODO fix
        # grad_output = grad(output_, input_, components, d)
        # result = div(grad_output, input_, d=d)
    elif method == 'std':
        if len(components) == 1:
            # scalar field: sum the pure second derivatives over all d
            grad_output = grad(output_, input_, components=components, d=d)
            result = torch.zeros(output_.shape[0], 1)
            for i, label in enumerate(grad_output.labels):
                gg = grad(grad_output, input_, d=d, components=[label])
                result[:, 0] += gg[:, i]
            labels = [f'dd{components[0]}']
        else:
            # vector field: one second derivative per (component, variable)
            result = torch.empty(input_.shape[0], len(components))
            labels = [None] * len(components)
            for idx, (ci, di) in enumerate(zip(components, d)):
                if not isinstance(ci, list):
                    ci = [ci]
                if not isinstance(di, list):
                    di = [di]
                grad_output = grad(output_, input_, components=ci, d=di)
                result[:, idx] = grad(grad_output, input_, d=di).flatten()
                labels[idx] = f'dd{ci[0]}dd{di[0]}'
    else:
        raise ValueError(f'unknown method: {method}')

    return LabelTensor(result, labels)
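
# Example usage of nabla (a minimal, illustrative sketch with assumed labels):
#
#   pts = LabelTensor(torch.rand((10, 2), requires_grad=True), ['x', 'y'])
#   u = LabelTensor(torch.sum(pts**2, dim=1, keepdim=True), ['u'])
#   lap_u = nabla(u, pts)  # d^2u/dx^2 + d^2u/dy^2, labels: ['ddu']
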
def advection(output_, input_):
    """
    Compute the advection term, i.e. the scalar product of the field
    ``output_`` with the gradient of each of its components.

    :param LabelTensor output_: the (vector-valued) output tensor.
    :param LabelTensor input_: the input tensor with respect to which the
        gradients are computed.
    :return: the advection term as a LabelTensor, one column per component.
    :rtype: LabelTensor
    """
    dimension = len(output_.labels)
    for i, label in enumerate(output_.labels):
        # compute u dot gradient in each direction
        gradient_loc = grad(output_.extract([label]),
                            input_).extract(input_.labels[:dimension])
        dim_0 = gradient_loc.shape[0]
        dim_1 = gradient_loc.shape[1]
        # batched scalar product between the field and the local gradient
        u_dot_grad_loc = torch.bmm(output_.view(dim_0, 1, dim_1),
                                   gradient_loc.view(dim_0, dim_1, 1))
        u_dot_grad_loc = LabelTensor(torch.reshape(u_dot_grad_loc,
                                                   (u_dot_grad_loc.shape[0],
                                                    u_dot_grad_loc.shape[1])),
                                     [input_.labels[i]])
        if i == 0:
            adv_term = u_dot_grad_loc
        else:
            adv_term = adv_term.append(u_dot_grad_loc)

    return adv_term
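
# Example usage of advection (a minimal, illustrative sketch: the labels and
# the construction of the velocity field are assumptions):
#
#   pts = LabelTensor(torch.rand((10, 2), requires_grad=True), ['x', 'y'])
#   velocity = LabelTensor(torch.cat((pts.extract(['x']) * pts.extract(['y']),
#                                     pts.extract(['x']) + pts.extract(['y'])),
#                                    dim=1), ['ux', 'uy'])
#   adv = advection(velocity, pts)  # (u . grad)u, one column per component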