@@ -116,7 +116,10 @@ class LabelTensor(torch.Tensor):
         new_data = self[:, indeces].float()
         new_labels = [self.labels[idx] for idx in indeces]
-        extracted_tensor = LabelTensor(new_data, new_labels)
+        extracted_tensor = new_data.as_subclass(LabelTensor)
+        extracted_tensor.labels = new_labels
+
         return extracted_tensor
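The hunk above swaps the copying LabelTensor(...) constructor for
torch.Tensor.as_subclass, which reinterprets the same storage under the
subclass and lets labels be attached afterwards. A minimal sketch of the
pattern; Labeled is a hypothetical stand-in for LabelTensor:

    import torch

    class Labeled(torch.Tensor):
        pass

    t = torch.rand(4, 3)
    lt = t.as_subclass(Labeled)            # reinterpret t, no data copy
    lt.labels = ['x', 'y', 'z']            # metadata attached after the cast
    assert lt.data_ptr() == t.data_ptr()   # both views share one storage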
@@ -150,9 +153,15 @@ class LabelTensor(torch.Tensor):
             tensor2.repeat_interleave(n1, dim=0),
             labels=tensor2.labels)
         new_tensor = torch.cat((tensor1, tensor2), dim=1)
-        return LabelTensor(new_tensor, new_labels)
+        new_tensor = new_tensor.as_subclass(LabelTensor)
+        new_tensor.labels = new_labels
+        return new_tensor

     def __str__(self):
-        s = f'labels({str(self.labels)})\n'
+        if hasattr(self, 'labels'):
+            s = f'labels({str(self.labels)})\n'
+        else:
+            s = 'no labels\n'
         s += super().__str__()
         return s
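The hasattr guard is needed because as_subclass never runs __init__, so a
LabelTensor produced by the cast has no labels attribute until one is
assigned. A small self-contained sketch, again with a hypothetical Labeled
class rather than the real LabelTensor:

    import torch

    class Labeled(torch.Tensor):
        def __str__(self):
            if hasattr(self, 'labels'):
                return f'labels({self.labels})\n' + super().__str__()
            return 'no labels\n' + super().__str__()

    lt = torch.rand(2, 2).as_subclass(Labeled)
    print(lt)                # 'no labels': __init__ never ran
    lt.labels = ['a', 'b']
    print(lt)                # now prints labels(['a', 'b']) first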
@@ -129,10 +129,10 @@ class DeepONet(torch.nn.Module):
         # output_ = self.reduction(inner_input)
         # print(output_.shape)
-        print(branch_output.shape)
-        print(trunk_output.shape)
         output_ = self.reduction(trunk_output * branch_output)
-        output_ = LabelTensor(output_, self.output_variables)
+        # output_ = LabelTensor(output_, self.output_variables)
+        output_ = output_.as_subclass(LabelTensor)
+        output_.labels = self.output_variables
         # local_size = int(trunk_output.shape[1]/self.output_dimension)
         # for i, var in enumerate(self.output_variables):
         #     start = i*local_size
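For context, the line this hunk touches is the standard DeepONet
combination: branch and trunk embeddings of equal width are multiplied
elementwise and reduced to the output. A hedged sketch with illustrative
layer sizes only, not the actual PINA networks:

    import torch

    branch = torch.nn.Linear(10, 16)   # encodes the sensed input function
    trunk = torch.nn.Linear(2, 16)     # encodes the evaluation coordinates

    u = torch.rand(32, 10)             # batch of input-function samples
    y = torch.rand(32, 2)              # batch of query points

    # the reduction here is a plain sum over the embedding dimension
    output = (branch(u) * trunk(y)).sum(dim=1, keepdim=True)
    print(output.shape)                # torch.Size([32, 1])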
@@ -97,9 +97,9 @@ class FeedForward(torch.nn.Module):
         for i, feature in enumerate(self.extra_features):
             x = x.append(feature(x))

-        output = self.model(x)
+        output = self.model(x).as_subclass(LabelTensor)

         if self.output_variables:
-            return LabelTensor(output, self.output_variables)
-        else:
-            return output
+            output.labels = self.output_variables
+        return output
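The extra-features loop above appends each feature's output as additional
input columns before the model sees them. A toy sketch of the idea, where
feat is a hypothetical engineered feature, and torch.cat stands in for the
label-aware LabelTensor.append:

    import torch

    x = torch.rand(8, 2)

    def feat(t):
        return torch.sin(t[:, :1])           # hypothetical extra feature

    x_aug = torch.cat((x, feat(x)), dim=1)   # what append does, minus labels
    model = torch.nn.Linear(x_aug.shape[1], 1)
    print(model(x_aug).shape)                # torch.Size([8, 1])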
@@ -19,14 +19,16 @@ def grad(output_, input_, components=None, d=None):
         raise RuntimeError

     output_fieldname = output_.labels[0]

     gradients = torch.autograd.grad(
         output_,
         input_,
-        grad_outputs=torch.ones(output_.size()).to(
-            dtype=input_.dtype,
-            device=input_.device),
-        create_graph=True, retain_graph=True, allow_unused=True)[0]
+        grad_outputs=torch.ones(output_.size(), dtype=output_.dtype,
+                                device=output_.device),
+        create_graph=True,
+        retain_graph=True,
+        allow_unused=True
+    )[0]
     gradients.labels = input_.labels
     gradients = gradients.extract(d)
     gradients.labels = [f'd{output_fieldname}d{i}' for i in d]
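The rewritten call builds grad_outputs directly with the output's dtype and
device instead of allocating on CPU and hopping through .to(). A minimal
runnable sketch of the same torch.autograd.grad pattern:

    import torch

    x = torch.rand(5, 2, requires_grad=True)
    y = (x ** 2).sum(dim=1, keepdim=True)

    g = torch.autograd.grad(
        y, x,
        grad_outputs=torch.ones(y.size(), dtype=y.dtype, device=y.device),
        create_graph=True,
        retain_graph=True,
        allow_unused=True,
    )[0]
    print(torch.allclose(g, 2 * x))    # True: d(sum x_i^2)/dx = 2x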
@@ -83,19 +85,16 @@ def div(output_, input_, components=None, d=None):
         raise ValueError

     grad_output = grad(output_, input_, components, d)
-    div = torch.zeros(input_.shape[0], 1)
-    # print(grad_output)
-    # print('empty', div)
+    div = torch.zeros(input_.shape[0], 1, device=output_.device)
     labels = [None] * len(components)
     for i, (c, d) in enumerate(zip(components, d)):
         c_fields = f'd{c}d{d}'
-        # print(c_fields)
         div[:, 0] += grad_output.extract(c_fields).sum(axis=1)
         labels[i] = c_fields
-    # print('full', div)
-    # print(labels)

-    return LabelTensor(div, ['+'.join(labels)])
+    div = div.as_subclass(LabelTensor)
+    div.labels = ['+'.join(labels)]
+    return div


 def nabla(output_, input_, components=None, d=None, method='std'):
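The device argument matters because the accumulator is filled in place from
gradient columns that may live on the GPU; mixing a CPU zeros tensor with
CUDA operands would raise. A small sketch of the fix with a synthetic
stand-in for the model output:

    import torch

    out = torch.rand(10, 2)                  # stands in for the network output
    div = torch.zeros(out.shape[0], 1, device=out.device)
    for i in range(out.shape[1]):
        div[:, 0] += out[:, i]               # per-component accumulation
    print(torch.allclose(div[:, 0], out.sum(dim=1)))   # True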
@@ -120,14 +119,15 @@ def nabla(output_, input_, components=None, d=None, method='std'):

     if len(components) == 1:
         grad_output = grad(output_, input_, components=components, d=d)
-        result = torch.zeros(output_.shape[0], 1)
+        result = torch.zeros(output_.shape[0], 1, device=output_.device)
         for i, label in enumerate(grad_output.labels):
             gg = grad(grad_output, input_, d=d, components=[label])
             result[:, 0] += gg[:, i]
         labels = [f'dd{components[0]}']

     else:
-        result = torch.empty(input_.shape[0], len(components))
+        result = torch.empty(input_.shape[0], len(components),
+                             device=output_.device)
         labels = [None] * len(components)
         for idx, (ci, di) in enumerate(zip(components, d)):
@@ -140,28 +140,20 @@ def nabla(output_, input_, components=None, d=None, method='std'):
             result[:, idx] = grad(grad_output, input_, d=di).flatten()
             labels[idx] = f'dd{ci}dd{di}'

-    return LabelTensor(result, labels)
+    result = result.as_subclass(LabelTensor)
+    result.labels = labels
+    return result


-def advection(output_, input_):
-    """
-    TODO
-    """
-    dimension = len(output_.labels)
-    for i, label in enumerate(output_.labels):
-        # compute u dot gradient in each direction
-        gradient_loc = grad(output_.extract([label]),
-                            input_).extract(input_.labels[:dimension])
-        dim_0 = gradient_loc.shape[0]
-        dim_1 = gradient_loc.shape[1]
-        u_dot_grad_loc = torch.bmm(output_.view(dim_0, 1, dim_1),
-                                   gradient_loc.view(dim_0, dim_1, 1))
-        u_dot_grad_loc = LabelTensor(torch.reshape(u_dot_grad_loc,
-                                                   (u_dot_grad_loc.shape[0],
-                                                    u_dot_grad_loc.shape[1])),
-                                     [input_.labels[i]])
-        if i == 0:
-            adv_term = u_dot_grad_loc
-        else:
-            adv_term = adv_term.append(u_dot_grad_loc)
-    return adv_term
+def advection(output_, input_, velocity_field, components=None, d=None):
+    if d is None:
+        d = input_.labels
+
+    if components is None:
+        components = output_.labels
+
+    tmp = grad(output_, input_, components, d
+               ).reshape(-1, len(components), len(d)).transpose(0, 1)
+    tmp *= output_.extract(velocity_field)
+    return tmp.sum(dim=2).T
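The rewrite replaces the per-component bmm loop with one batched
contraction: the flat gradient columns are reshaped to (components, batch,
dims), scaled by the velocity columns, and summed over dims, which is
exactly v . grad(c) for each component. A sketch with synthetic stand-ins
for the grad() and extract() results:

    import torch

    batch, n_comp, n_dim = 7, 2, 2
    grads = torch.rand(batch, n_comp * n_dim)    # flat d{c}d{x}, d{c}d{y}, ...
    velocity = torch.rand(batch, n_dim)          # columns picked by label

    tmp = grads.reshape(-1, n_comp, n_dim).transpose(0, 1)  # (comp, batch, dim)
    tmp = tmp * velocity                         # broadcasts over components
    adv = tmp.sum(dim=2).T                       # back to (batch, comp)
    print(adv.shape)                             # torch.Size([7, 2])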
@@ -70,24 +70,24 @@ class Plotter:
         """
         """

-        grids = [p_.reshape(res, res) for p_ in pts.extract(v).T]
+        grids = [p_.reshape(res, res) for p_ in pts.extract(v).cpu().T]

         pred_output = pred.reshape(res, res)
         if truth_solution:
             truth_output = truth_solution(pts).float().reshape(res, res)
             fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))

-            cb = getattr(ax[0], method)(*grids, pred_output.detach(), **kwargs)
+            cb = getattr(ax[0], method)(*grids, pred_output.cpu().detach(), **kwargs)
             fig.colorbar(cb, ax=ax[0])
-            cb = getattr(ax[1], method)(*grids, truth_output.detach(), **kwargs)
+            cb = getattr(ax[1], method)(*grids, truth_output.cpu().detach(), **kwargs)
             fig.colorbar(cb, ax=ax[1])
             cb = getattr(ax[2], method)(*grids,
-                                        (truth_output-pred_output).detach(),
+                                        (truth_output-pred_output).cpu().detach(),
                                         **kwargs)
             fig.colorbar(cb, ax=ax[2])
         else:
             fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
-            cb = getattr(ax, method)(*grids, pred_output.detach(), **kwargs)
+            cb = getattr(ax, method)(*grids, pred_output.cpu().detach(), **kwargs)
             fig.colorbar(cb, ax=ax)
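The added .cpu() calls exist because matplotlib cannot consume CUDA
tensors: everything must be detached from the autograd graph and moved to
host memory before plotting. A minimal sketch, with pcolormesh standing in
for the getattr(ax, method) dispatch:

    import matplotlib.pyplot as plt
    import torch

    res = 16
    pred = torch.rand(res, res, requires_grad=True)   # pretend network output

    fig, ax = plt.subplots()
    cb = ax.pcolormesh(pred.cpu().detach())   # safe for CPU and GPU tensors
    fig.colorbar(cb, ax=ax)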
@@ -103,9 +103,13 @@ class Plotter:
         ]
         pts = pinn.problem.domain.sample(res, 'grid', variables=v)

-        for variable, value in fixed_variables.items():
-            new = LabelTensor(torch.ones(pts.shape[0], 1)*value, [variable])
-            pts = pts.append(new)
+        fixed_pts = torch.ones(pts.shape[0], len(fixed_variables))
+        fixed_pts *= torch.tensor(list(fixed_variables.values()))
+        fixed_pts = fixed_pts.as_subclass(LabelTensor)
+        fixed_pts.labels = list(fixed_variables.keys())
+
+        pts = pts.append(fixed_pts)
+        pts = pts.to(device=pinn.device)

         predicted_output = pinn.model(pts)
         if isinstance(components, str):
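The new version builds all fixed-variable columns at once: a ones matrix is
scaled column-wise by broadcasting the fixed values, replacing the old
per-variable append loop. A sketch with illustrative variable names:

    import torch

    fixed_variables = {'mu': 0.5, 't': 1.0}     # hypothetical fixed values
    n_pts = 4
    fixed_pts = torch.ones(n_pts, len(fixed_variables))
    fixed_pts *= torch.tensor(list(fixed_variables.values()))
    print(fixed_pts)    # each column holds its variable's constant value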