fix codacy

This commit is contained in:
Dario Coscia
2024-01-31 12:48:36 +01:00
committed by Nicola Demo
parent 513144dfaf
commit b38c640c4d

View File

@@ -15,7 +15,8 @@ class Plotter:
""" """
Plot the training grid samples. Plot the training grid samples.
:param AbstractProblem problem: The PINA problem from where to plot the domain. :param AbstractProblem problem: The PINA problem from where to plot
the domain.
:param list(str) variables: Variables to plot. If None, all variables :param list(str) variables: Variables to plot. If None, all variables
are plotted. If 'spatial', only spatial variables are plotted. If are plotted. If 'spatial', only spatial variables are plotted. If
'temporal', only temporal variables are plotted. Defaults to None. 'temporal', only temporal variables are plotted. Defaults to None.
@@ -39,7 +40,8 @@ class Plotter:
variables = problem.temporal_domain.variables variables = problem.temporal_domain.variables
if len(variables) not in [1, 2, 3]: if len(variables) not in [1, 2, 3]:
raise ValueError('Samples can be plotted only in dimensions 1, 2 and 3.') raise ValueError('Samples can be plotted only in '
'dimensions 1, 2 and 3.')
fig = plt.figure() fig = plt.figure()
proj = '3d' if len(variables) == 3 else None proj = '3d' if len(variables) == 3 else None
@@ -96,7 +98,8 @@ class Plotter:
if truth_solution: if truth_solution:
truth_output = truth_solution(pts).detach() truth_output = truth_solution(pts).detach()
ax.plot(pts.extract(v), truth_output, label='True solution', **kwargs) ax.plot(pts.extract(v), truth_output,
label='True solution', **kwargs)
# TODO: pred is a torch.Tensor, so no labels is available # TODO: pred is a torch.Tensor, so no labels is available
# extra variable for labels should be # extra variable for labels should be
@@ -197,7 +200,8 @@ class Plotter:
if len(components) > 1: if len(components) > 1:
raise NotImplementedError('Multidimensional plots are not implemented, ' raise NotImplementedError('Multidimensional plots are not implemented, '
'set components to an available components of the problem.') 'set components to an available components of'
' the problem.')
v = [ v = [
var for var in solver.problem.input_variables var for var in solver.problem.input_variables
if var not in fixed_variables.keys() if var not in fixed_variables.keys()
@@ -213,7 +217,8 @@ class Plotter:
pts = pts.to(device=solver.device) pts = pts.to(device=solver.device)
# computing soluting and sending to cpu # computing soluting and sending to cpu
predicted_output = solver.forward(pts).extract(components).as_subclass(torch.Tensor).cpu().detach() predicted_output = solver.forward(pts).extract(components)
predicted_output = predicted_output.as_subclass(torch.Tensor).cpu().detach()
pts = pts.cpu() pts = pts.cpu()
truth_solution = getattr(solver.problem, 'truth_solution', None) truth_solution = getattr(solver.problem, 'truth_solution', None)