Merge pull request #22 from dario-coscia/main

Small update in _compute_error() in pinn.py
This commit is contained in:
Nicola Demo
2022-07-21 17:02:59 +02:00
committed by GitHub

View File

@@ -77,9 +77,9 @@ class PINN(object):
:param vec torch.tensor: the tensor :param vec torch.tensor: the tensor
""" """
if isinstance(self.error_norm, int): if isinstance(self.error_norm, int):
return torch.sum(torch.abs(vec**self.error_norm))**(1./self.error_norm) return torch.linalg.vector_norm(vec, ord = self.error_norm, dtype=self.dytpe)
elif self.error_norm == 'mse': elif self.error_norm == 'mse':
return torch.mean(vec**2) return torch.mean(vec.pow(2))
elif self.error_norm == 'me': elif self.error_norm == 'me':
return torch.mean(torch.abs(vec)) return torch.mean(torch.abs(vec))
else: else: