add docs
This commit is contained in:
68
pina/pinn.py
68
pina/pinn.py
@@ -88,10 +88,13 @@ class PINN(object):
|
||||
|
||||
@property
|
||||
def problem(self):
|
||||
""" The problem formulation."""
|
||||
return self._problem
|
||||
|
||||
@problem.setter
|
||||
def problem(self, problem):
|
||||
"""
|
||||
Set the problem formulation."""
|
||||
if not isinstance(problem, AbstractProblem):
|
||||
raise TypeError
|
||||
self._problem = problem
|
||||
@@ -99,11 +102,11 @@ class PINN(object):
|
||||
def _compute_norm(self, vec):
|
||||
"""
|
||||
Compute the norm of the `vec` one-dimensional tensor based on the
|
||||
`self.error_norm` attribute.
|
||||
`self.error_norm` attribute.
|
||||
|
||||
.. todo: complete
|
||||
|
||||
:param vec torch.tensor: the tensor
|
||||
:param torch.Tensor vec: the tensor
|
||||
"""
|
||||
if isinstance(self.error_norm, int):
|
||||
return torch.linalg.vector_norm(vec, ord=self.error_norm, dtype=self.dytpe)
|
||||
@@ -115,7 +118,11 @@ class PINN(object):
|
||||
raise RuntimeError
|
||||
|
||||
def save_state(self, filename):
|
||||
"""
|
||||
Save the state of the model.
|
||||
|
||||
:param str filename: the filename to save the state to.
|
||||
"""
|
||||
checkpoint = {
|
||||
'epoch': self.trained_epoch,
|
||||
'model_state': self.model.state_dict(),
|
||||
@@ -133,6 +140,11 @@ class PINN(object):
|
||||
torch.save(checkpoint, filename)
|
||||
|
||||
def load_state(self, filename):
|
||||
"""
|
||||
Load the state of the model.
|
||||
|
||||
:param str filename: the filename to load the state from.
|
||||
"""
|
||||
|
||||
checkpoint = torch.load(filename)
|
||||
self.model.load_state_dict(checkpoint['model_state'])
|
||||
@@ -298,32 +310,32 @@ class PINN(object):
|
||||
|
||||
return sum(losses).item()
|
||||
|
||||
def error(self, dtype='l2', res=100):
|
||||
# def error(self, dtype='l2', res=100):
|
||||
|
||||
import numpy as np
|
||||
if hasattr(self.problem, 'truth_solution') and self.problem.truth_solution is not None:
|
||||
pts_container = []
|
||||
for mn, mx in self.problem.domain_bound:
|
||||
pts_container.append(np.linspace(mn, mx, res))
|
||||
grids_container = np.meshgrid(*pts_container)
|
||||
Z_true = self.problem.truth_solution(*grids_container)
|
||||
# import numpy as np
|
||||
# if hasattr(self.problem, 'truth_solution') and self.problem.truth_solution is not None:
|
||||
# pts_container = []
|
||||
# for mn, mx in self.problem.domain_bound:
|
||||
# pts_container.append(np.linspace(mn, mx, res))
|
||||
# grids_container = np.meshgrid(*pts_container)
|
||||
# Z_true = self.problem.truth_solution(*grids_container)
|
||||
|
||||
elif hasattr(self.problem, 'data_solution') and self.problem.data_solution is not None:
|
||||
grids_container = self.problem.data_solution['grid']
|
||||
Z_true = self.problem.data_solution['grid_solution']
|
||||
try:
|
||||
unrolled_pts = torch.tensor([t.flatten() for t in grids_container]).T.to(
|
||||
dtype=self.dtype, device=self.device)
|
||||
Z_pred = self.model(unrolled_pts)
|
||||
Z_pred = Z_pred.detach().numpy().reshape(grids_container[0].shape)
|
||||
# elif hasattr(self.problem, 'data_solution') and self.problem.data_solution is not None:
|
||||
# grids_container = self.problem.data_solution['grid']
|
||||
# Z_true = self.problem.data_solution['grid_solution']
|
||||
# try:
|
||||
# unrolled_pts = torch.tensor([t.flatten() for t in grids_container]).T.to(
|
||||
# dtype=self.dtype, device=self.device)
|
||||
# Z_pred = self.model(unrolled_pts)
|
||||
# Z_pred = Z_pred.detach().numpy().reshape(grids_container[0].shape)
|
||||
|
||||
if dtype == 'l2':
|
||||
return np.linalg.norm(Z_pred - Z_true)/np.linalg.norm(Z_true)
|
||||
else:
|
||||
# TODO H1
|
||||
pass
|
||||
except:
|
||||
print("")
|
||||
print("Something went wrong...")
|
||||
print(
|
||||
"Not able to compute the error. Please pass a data solution or a true solution")
|
||||
# if dtype == 'l2':
|
||||
# return np.linalg.norm(Z_pred - Z_true)/np.linalg.norm(Z_true)
|
||||
# else:
|
||||
# # TODO H1
|
||||
# pass
|
||||
# except:
|
||||
# print("")
|
||||
# print("Something went wrong...")
|
||||
# print(
|
||||
# "Not able to compute the error. Please pass a data solution or a true solution")
|
||||
|
||||
Reference in New Issue
Block a user