Update pinn.py

This commit is contained in:
Dario Coscia
2022-07-21 16:12:21 +02:00
committed by GitHub
parent e8c2f87460
commit 9c3f94d3ec

View File

@@ -77,9 +77,9 @@ class PINN(object):
:param vec torch.tensor: the tensor
"""
if isinstance(self.error_norm, int):
return torch.sum(torch.abs(vec**self.error_norm))**(1./self.error_norm)
return torch.linalg.vector_norm(vec, ord = self.error_norm, dtype=self.dytpe)
elif self.error_norm == 'mse':
return torch.mean(vec**2)
return torch.mean(vec.pow(2))
elif self.error_norm == 'me':
return torch.mean(torch.abs(vec))
else: