From 06932196a80afd4810f052e877355058302b91e8 Mon Sep 17 00:00:00 2001 From: Nicola Demo Date: Thu, 8 Sep 2022 17:31:49 +0200 Subject: [PATCH] CUDA option for labeltensor (#23) * fix cuda device for labeltensor --- pina/label_tensor.py | 15 +++++++-- pina/model/deeponet.py | 6 ++-- pina/model/feed_forward.py | 8 ++--- pina/operators.py | 68 +++++++++++++++++--------------------- pina/plotter.py | 20 ++++++----- 5 files changed, 61 insertions(+), 56 deletions(-) diff --git a/pina/label_tensor.py b/pina/label_tensor.py index 334d740..0803fb4 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -116,7 +116,10 @@ class LabelTensor(torch.Tensor): new_data = self[:, indeces].float() new_labels = [self.labels[idx] for idx in indeces] - extracted_tensor = LabelTensor(new_data, new_labels) + + extracted_tensor = new_data.as_subclass(LabelTensor) + extracted_tensor.labels = new_labels + return extracted_tensor @@ -150,9 +153,15 @@ class LabelTensor(torch.Tensor): tensor2.repeat_interleave(n1, dim=0), labels=tensor2.labels) new_tensor = torch.cat((tensor1, tensor2), dim=1) - return LabelTensor(new_tensor, new_labels) + + new_tensor = new_tensor.as_subclass(LabelTensor) + new_tensor.labels = new_labels + return new_tensor def __str__(self): - s = f'labels({str(self.labels)})\n' + if hasattr(self, 'labels'): + s = f'labels({str(self.labels)})\n' + else: + s = 'no labels\n' s += super().__str__() return s diff --git a/pina/model/deeponet.py b/pina/model/deeponet.py index fddbe8d..942f92f 100644 --- a/pina/model/deeponet.py +++ b/pina/model/deeponet.py @@ -129,10 +129,10 @@ class DeepONet(torch.nn.Module): # output_ = self.reduction(inner_input) # print(output_.shape) - print(branch_output.shape) - print(trunk_output.shape) output_ = self.reduction(trunk_output * branch_output) - output_ = LabelTensor(output_, self.output_variables) + # output_ = LabelTensor(output_, self.output_variables) + output_ = output_.as_subclass(LabelTensor) + output_.labels = self.output_variables # local_size = int(trunk_output.shape[1]/self.output_dimension) # for i, var in enumerate(self.output_variables): # start = i*local_size diff --git a/pina/model/feed_forward.py b/pina/model/feed_forward.py index 7639bf2..ac5854c 100644 --- a/pina/model/feed_forward.py +++ b/pina/model/feed_forward.py @@ -97,9 +97,9 @@ class FeedForward(torch.nn.Module): for i, feature in enumerate(self.extra_features): x = x.append(feature(x)) - output = self.model(x) + output = self.model(x).as_subclass(LabelTensor) if self.output_variables: - return LabelTensor(output, self.output_variables) - else: - return output + output.labels = self.output_variables + + return output diff --git a/pina/operators.py b/pina/operators.py index cd1456f..b816deb 100644 --- a/pina/operators.py +++ b/pina/operators.py @@ -19,14 +19,16 @@ def grad(output_, input_, components=None, d=None): raise RuntimeError output_fieldname = output_.labels[0] - gradients = torch.autograd.grad( output_, input_, - grad_outputs=torch.ones(output_.size()).to( - dtype=input_.dtype, - device=input_.device), - create_graph=True, retain_graph=True, allow_unused=True)[0] + grad_outputs=torch.ones(output_.size(), dtype=output_.dtype, + device=output_.device), + create_graph=True, + retain_graph=True, + allow_unused=True + )[0] + gradients.labels = input_.labels gradients = gradients.extract(d) gradients.labels = [f'd{output_fieldname}d{i}' for i in d] @@ -83,19 +85,16 @@ def div(output_, input_, components=None, d=None): raise ValueError grad_output = grad(output_, input_, components, d) - div = torch.zeros(input_.shape[0], 1) - # print(grad_output) - # print('empty', div) + div = torch.zeros(input_.shape[0], 1, device=output_.device) labels = [None] * len(components) for i, (c, d) in enumerate(zip(components, d)): c_fields = f'd{c}d{d}' - # print(c_fields) div[:, 0] += grad_output.extract(c_fields).sum(axis=1) labels[i] = c_fields - # print('full', div) - # print(labels) - return LabelTensor(div, ['+'.join(labels)]) + div = div.as_subclass(LabelTensor) + div.labels = ['+'.join(labels)] + return div def nabla(output_, input_, components=None, d=None, method='std'): @@ -120,14 +119,15 @@ def nabla(output_, input_, components=None, d=None, method='std'): if len(components) == 1: grad_output = grad(output_, input_, components=components, d=d) - result = torch.zeros(output_.shape[0], 1) + result = torch.zeros(output_.shape[0], 1, device=output_.device) for i, label in enumerate(grad_output.labels): - gg = grad(grad_output, input_, d=d, components=[label]) + gg = grad(grad_output, input_, d=d, components=[label]) result[:, 0] += gg[:, i] labels = [f'dd{components[0]}'] else: - result = torch.empty(input_.shape[0], len(components)) + result = torch.empty(input_.shape[0], len(components), + device=output_.device) labels = [None] * len(components) for idx, (ci, di) in enumerate(zip(components, d)): @@ -140,28 +140,20 @@ def nabla(output_, input_, components=None, d=None, method='std'): result[:, idx] = grad(grad_output, input_, d=di).flatten() labels[idx] = f'dd{ci}dd{di}' - return LabelTensor(result, labels) + result = result.as_subclass(LabelTensor) + result.labels = labels + return result -def advection(output_, input_): - """ - TODO - """ - dimension = len(output_.labels) - for i, label in enumerate(output_.labels): - # compute u dot gradient in each direction - gradient_loc = grad(output_.extract([label]), - input_).extract(input_.labels[:dimension]) - dim_0 = gradient_loc.shape[0] - dim_1 = gradient_loc.shape[1] - u_dot_grad_loc = torch.bmm(output_.view(dim_0, 1, dim_1), - gradient_loc.view(dim_0, dim_1, 1)) - u_dot_grad_loc = LabelTensor(torch.reshape(u_dot_grad_loc, - (u_dot_grad_loc.shape[0], - u_dot_grad_loc.shape[1])), - [input_.labels[i]]) - if i == 0: - adv_term = u_dot_grad_loc - else: - adv_term = adv_term.append(u_dot_grad_loc) - return adv_term +def advection(output_, input_, velocity_field, components=None, d=None): + if d is None: + d = input_.labels + + if components is None: + components = output_.labels + + tmp = grad(output_, input_, components, d + ).reshape(-1, len(components), len(d)).transpose(0, 1) + + tmp *= output_.extract(velocity_field) + return tmp.sum(dim=2).T diff --git a/pina/plotter.py b/pina/plotter.py index 40fcc3e..7881fda 100644 --- a/pina/plotter.py +++ b/pina/plotter.py @@ -70,24 +70,24 @@ class Plotter: """ """ - grids = [p_.reshape(res, res) for p_ in pts.extract(v).T] + grids = [p_.reshape(res, res) for p_ in pts.extract(v).cpu().T] pred_output = pred.reshape(res, res) if truth_solution: truth_output = truth_solution(pts).float().reshape(res, res) fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16, 6)) - cb = getattr(ax[0], method)(*grids, pred_output.detach(), **kwargs) + cb = getattr(ax[0], method)(*grids, pred_output.cpu().detach(), **kwargs) fig.colorbar(cb, ax=ax[0]) - cb = getattr(ax[1], method)(*grids, truth_output.detach(), **kwargs) + cb = getattr(ax[1], method)(*grids, truth_output.cpu().detach(), **kwargs) fig.colorbar(cb, ax=ax[1]) cb = getattr(ax[2], method)(*grids, - (truth_output-pred_output).detach(), + (truth_output-pred_output).cpu().detach(), **kwargs) fig.colorbar(cb, ax=ax[2]) else: fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6)) - cb = getattr(ax, method)(*grids, pred_output.detach(), **kwargs) + cb = getattr(ax, method)(*grids, pred_output.cpu().detach(), **kwargs) fig.colorbar(cb, ax=ax) @@ -103,9 +103,13 @@ class Plotter: ] pts = pinn.problem.domain.sample(res, 'grid', variables=v) - for variable, value in fixed_variables.items(): - new = LabelTensor(torch.ones(pts.shape[0], 1)*value, [variable]) - pts = pts.append(new) + fixed_pts = torch.ones(pts.shape[0], len(fixed_variables)) + fixed_pts *= torch.tensor(list(fixed_variables.values())) + fixed_pts = fixed_pts.as_subclass(LabelTensor) + fixed_pts.labels = list(fixed_variables.keys()) + + pts = pts.append(fixed_pts) + pts = pts.to(device=pinn.device) predicted_output = pinn.model(pts) if isinstance(components, str):