diff --git a/pina/pinn.py b/pina/pinn.py index 8a2527b..03929cb 100644 --- a/pina/pinn.py +++ b/pina/pinn.py @@ -77,9 +77,9 @@ class PINN(object): :param vec torch.tensor: the tensor """ if isinstance(self.error_norm, int): - return torch.sum(torch.abs(vec**self.error_norm))**(1./self.error_norm) + return torch.linalg.vector_norm(vec, ord = self.error_norm, dtype=self.dytpe) elif self.error_norm == 'mse': - return torch.mean(vec**2) + return torch.mean(vec.pow(2)) elif self.error_norm == 'me': return torch.mean(torch.abs(vec)) else: