Refactoring code
This commit is contained in:
22
pina/pinn.py
22
pina/pinn.py
@@ -1,12 +1,7 @@
|
||||
from mpmath import chebyt, chop, taylor
|
||||
|
||||
from .problem import Problem
|
||||
from .problem import AbstractProblem
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from .cube import Cube
|
||||
from .segment import Segment
|
||||
from .deep_feed_forward import DeepFeedForward
|
||||
from pina.label_tensor import LabelTensor
|
||||
torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
|
||||
|
||||
@@ -86,7 +81,7 @@ class PINN(object):
|
||||
|
||||
@problem.setter
|
||||
def problem(self, problem):
|
||||
if not isinstance(problem, Problem):
|
||||
if not isinstance(problem, AbstractProblem):
|
||||
raise TypeError
|
||||
self._problem = problem
|
||||
|
||||
@@ -157,7 +152,6 @@ class PINN(object):
|
||||
self.trained_epoch = checkpoint['epoch']
|
||||
self.history = checkpoint['history']
|
||||
|
||||
print(self.history)
|
||||
return self
|
||||
|
||||
|
||||
@@ -188,7 +182,6 @@ class PINN(object):
|
||||
def plot_pts(self, locations='all'):
|
||||
import matplotlib
|
||||
matplotlib.use('GTK3Agg')
|
||||
import matplotlib.pyplot as plt
|
||||
if locations == 'all':
|
||||
locations = [condition for condition in self.problem.conditions]
|
||||
|
||||
@@ -197,7 +190,6 @@ class PINN(object):
|
||||
#plt.plot(x.detach(), y.detach(), 'o', label=location)
|
||||
np.savetxt('burgers_{}_pts.txt'.format(location), self.input_pts[location].tensor.detach(), header='x y', delimiter=' ')
|
||||
|
||||
gggg
|
||||
|
||||
plt.legend()
|
||||
plt.show()
|
||||
@@ -234,7 +226,7 @@ class PINN(object):
|
||||
|
||||
if trial:
|
||||
import optuna
|
||||
trial.report(loss[0].item()+loss[1].item(), epoch)
|
||||
trial.report(sum(losses), epoch)
|
||||
if trial.should_prune():
|
||||
raise optuna.exceptions.TrialPruned()
|
||||
|
||||
@@ -282,6 +274,9 @@ class PINN(object):
|
||||
print("Something went wrong...")
|
||||
print("Not able to compute the error. Please pass a data solution or a true solution")
|
||||
|
||||
|
||||
|
||||
|
||||
def plot(self, res, filename=None, variable=None):
|
||||
'''
|
||||
'''
|
||||
@@ -289,6 +284,9 @@ class PINN(object):
|
||||
matplotlib.use('Agg')
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
self._plot_2D(res, filename, variable)
|
||||
print('TTTTTTTTTTTTTTTTTt')
|
||||
print(self.problem.bounds)
|
||||
pts_container = []
|
||||
#for mn, mx in [[-1, 1], [-1, 1]]:
|
||||
for mn, mx in [[0, 1], [0, 1]]:
|
||||
|
||||
Reference in New Issue
Block a user