@@ -70,24 +70,24 @@ class Plotter:
|
||||
"""
|
||||
"""
|
||||
|
||||
grids = [p_.reshape(res, res) for p_ in pts.extract(v).T]
|
||||
grids = [p_.reshape(res, res) for p_ in pts.extract(v).cpu().T]
|
||||
|
||||
pred_output = pred.reshape(res, res)
|
||||
if truth_solution:
|
||||
truth_output = truth_solution(pts).float().reshape(res, res)
|
||||
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
|
||||
|
||||
cb = getattr(ax[0], method)(*grids, pred_output.detach(), **kwargs)
|
||||
cb = getattr(ax[0], method)(*grids, pred_output.cpu().detach(), **kwargs)
|
||||
fig.colorbar(cb, ax=ax[0])
|
||||
cb = getattr(ax[1], method)(*grids, truth_output.detach(), **kwargs)
|
||||
cb = getattr(ax[1], method)(*grids, truth_output.cpu().detach(), **kwargs)
|
||||
fig.colorbar(cb, ax=ax[1])
|
||||
cb = getattr(ax[2], method)(*grids,
|
||||
(truth_output-pred_output).detach(),
|
||||
(truth_output-pred_output).cpu().detach(),
|
||||
**kwargs)
|
||||
fig.colorbar(cb, ax=ax[2])
|
||||
else:
|
||||
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
|
||||
cb = getattr(ax, method)(*grids, pred_output.detach(), **kwargs)
|
||||
cb = getattr(ax, method)(*grids, pred_output.cpu().detach(), **kwargs)
|
||||
fig.colorbar(cb, ax=ax)
|
||||
|
||||
|
||||
@@ -103,9 +103,13 @@ class Plotter:
|
||||
]
|
||||
pts = pinn.problem.domain.sample(res, 'grid', variables=v)
|
||||
|
||||
for variable, value in fixed_variables.items():
|
||||
new = LabelTensor(torch.ones(pts.shape[0], 1)*value, [variable])
|
||||
pts = pts.append(new)
|
||||
fixed_pts = torch.ones(pts.shape[0], len(fixed_variables))
|
||||
fixed_pts *= torch.tensor(list(fixed_variables.values()))
|
||||
fixed_pts = fixed_pts.as_subclass(LabelTensor)
|
||||
fixed_pts.labels = list(fixed_variables.keys())
|
||||
|
||||
pts = pts.append(fixed_pts)
|
||||
pts = pts.to(device=pinn.device)
|
||||
|
||||
predicted_output = pinn.model(pts)
|
||||
if isinstance(components, str):
|
||||
|
||||
Reference in New Issue
Block a user