Merge pull request #22 from dario-coscia/main
Small update in _compute_error() in pinn.py
This commit is contained in:
@@ -77,9 +77,9 @@ class PINN(object):
|
|||||||
:param vec torch.tensor: the tensor
|
:param vec torch.tensor: the tensor
|
||||||
"""
|
"""
|
||||||
if isinstance(self.error_norm, int):
|
if isinstance(self.error_norm, int):
|
||||||
return torch.sum(torch.abs(vec**self.error_norm))**(1./self.error_norm)
|
return torch.linalg.vector_norm(vec, ord = self.error_norm, dtype=self.dytpe)
|
||||||
elif self.error_norm == 'mse':
|
elif self.error_norm == 'mse':
|
||||||
return torch.mean(vec**2)
|
return torch.mean(vec.pow(2))
|
||||||
elif self.error_norm == 'me':
|
elif self.error_norm == 'me':
|
||||||
return torch.mean(torch.abs(vec))
|
return torch.mean(torch.abs(vec))
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user