fix codacy
This commit is contained in:
committed by
Nicola Demo
parent
513144dfaf
commit
b38c640c4d
@@ -15,7 +15,8 @@ class Plotter:
|
|||||||
"""
|
"""
|
||||||
Plot the training grid samples.
|
Plot the training grid samples.
|
||||||
|
|
||||||
:param AbstractProblem problem: The PINA problem from where to plot the domain.
|
:param AbstractProblem problem: The PINA problem from where to plot
|
||||||
|
the domain.
|
||||||
:param list(str) variables: Variables to plot. If None, all variables
|
:param list(str) variables: Variables to plot. If None, all variables
|
||||||
are plotted. If 'spatial', only spatial variables are plotted. If
|
are plotted. If 'spatial', only spatial variables are plotted. If
|
||||||
'temporal', only temporal variables are plotted. Defaults to None.
|
'temporal', only temporal variables are plotted. Defaults to None.
|
||||||
@@ -39,7 +40,8 @@ class Plotter:
|
|||||||
variables = problem.temporal_domain.variables
|
variables = problem.temporal_domain.variables
|
||||||
|
|
||||||
if len(variables) not in [1, 2, 3]:
|
if len(variables) not in [1, 2, 3]:
|
||||||
raise ValueError('Samples can be plotted only in dimensions 1, 2 and 3.')
|
raise ValueError('Samples can be plotted only in '
|
||||||
|
'dimensions 1, 2 and 3.')
|
||||||
|
|
||||||
fig = plt.figure()
|
fig = plt.figure()
|
||||||
proj = '3d' if len(variables) == 3 else None
|
proj = '3d' if len(variables) == 3 else None
|
||||||
@@ -96,7 +98,8 @@ class Plotter:
|
|||||||
|
|
||||||
if truth_solution:
|
if truth_solution:
|
||||||
truth_output = truth_solution(pts).detach()
|
truth_output = truth_solution(pts).detach()
|
||||||
ax.plot(pts.extract(v), truth_output, label='True solution', **kwargs)
|
ax.plot(pts.extract(v), truth_output,
|
||||||
|
label='True solution', **kwargs)
|
||||||
|
|
||||||
# TODO: pred is a torch.Tensor, so no labels is available
|
# TODO: pred is a torch.Tensor, so no labels is available
|
||||||
# extra variable for labels should be
|
# extra variable for labels should be
|
||||||
@@ -197,7 +200,8 @@ class Plotter:
|
|||||||
|
|
||||||
if len(components) > 1:
|
if len(components) > 1:
|
||||||
raise NotImplementedError('Multidimensional plots are not implemented, '
|
raise NotImplementedError('Multidimensional plots are not implemented, '
|
||||||
'set components to an available components of the problem.')
|
'set components to an available components of'
|
||||||
|
' the problem.')
|
||||||
v = [
|
v = [
|
||||||
var for var in solver.problem.input_variables
|
var for var in solver.problem.input_variables
|
||||||
if var not in fixed_variables.keys()
|
if var not in fixed_variables.keys()
|
||||||
@@ -213,7 +217,8 @@ class Plotter:
|
|||||||
pts = pts.to(device=solver.device)
|
pts = pts.to(device=solver.device)
|
||||||
|
|
||||||
# computing soluting and sending to cpu
|
# computing soluting and sending to cpu
|
||||||
predicted_output = solver.forward(pts).extract(components).as_subclass(torch.Tensor).cpu().detach()
|
predicted_output = solver.forward(pts).extract(components)
|
||||||
|
predicted_output = predicted_output.as_subclass(torch.Tensor).cpu().detach()
|
||||||
pts = pts.cpu()
|
pts = pts.cpu()
|
||||||
truth_solution = getattr(solver.problem, 'truth_solution', None)
|
truth_solution = getattr(solver.problem, 'truth_solution', None)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user