CUDA option for labeltensor (#23)

* fix cuda device for labeltensor
This commit is contained in:
Nicola Demo
2022-09-08 17:31:49 +02:00
committed by GitHub
parent 9b2ab7be41
commit 06932196a8
5 changed files with 61 additions and 56 deletions

View File

@@ -70,24 +70,24 @@ class Plotter:
"""
"""
grids = [p_.reshape(res, res) for p_ in pts.extract(v).T]
grids = [p_.reshape(res, res) for p_ in pts.extract(v).cpu().T]
pred_output = pred.reshape(res, res)
if truth_solution:
truth_output = truth_solution(pts).float().reshape(res, res)
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
cb = getattr(ax[0], method)(*grids, pred_output.detach(), **kwargs)
cb = getattr(ax[0], method)(*grids, pred_output.cpu().detach(), **kwargs)
fig.colorbar(cb, ax=ax[0])
cb = getattr(ax[1], method)(*grids, truth_output.detach(), **kwargs)
cb = getattr(ax[1], method)(*grids, truth_output.cpu().detach(), **kwargs)
fig.colorbar(cb, ax=ax[1])
cb = getattr(ax[2], method)(*grids,
(truth_output-pred_output).detach(),
(truth_output-pred_output).cpu().detach(),
**kwargs)
fig.colorbar(cb, ax=ax[2])
else:
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
cb = getattr(ax, method)(*grids, pred_output.detach(), **kwargs)
cb = getattr(ax, method)(*grids, pred_output.cpu().detach(), **kwargs)
fig.colorbar(cb, ax=ax)
@@ -103,9 +103,13 @@ class Plotter:
]
pts = pinn.problem.domain.sample(res, 'grid', variables=v)
for variable, value in fixed_variables.items():
new = LabelTensor(torch.ones(pts.shape[0], 1)*value, [variable])
pts = pts.append(new)
fixed_pts = torch.ones(pts.shape[0], len(fixed_variables))
fixed_pts *= torch.tensor(list(fixed_variables.values()))
fixed_pts = fixed_pts.as_subclass(LabelTensor)
fixed_pts.labels = list(fixed_variables.keys())
pts = pts.append(fixed_pts)
pts = pts.to(device=pinn.device)
predicted_output = pinn.model(pts)
if isinstance(components, str):