Refactoring code

This commit is contained in:
Your Name
2022-01-27 14:55:42 +01:00
parent fb16fc7f3a
commit fa8ffd5042
32 changed files with 417 additions and 442 deletions

View File

@@ -1,12 +1,7 @@
from mpmath import chebyt, chop, taylor
from .problem import Problem
from .problem import AbstractProblem
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from .cube import Cube
from .segment import Segment
from .deep_feed_forward import DeepFeedForward
from pina.label_tensor import LabelTensor
torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
@@ -86,7 +81,7 @@ class PINN(object):
@problem.setter
def problem(self, problem):
if not isinstance(problem, Problem):
if not isinstance(problem, AbstractProblem):
raise TypeError
self._problem = problem
@@ -157,7 +152,6 @@ class PINN(object):
self.trained_epoch = checkpoint['epoch']
self.history = checkpoint['history']
print(self.history)
return self
@@ -188,7 +182,6 @@ class PINN(object):
def plot_pts(self, locations='all'):
import matplotlib
matplotlib.use('GTK3Agg')
import matplotlib.pyplot as plt
if locations == 'all':
locations = [condition for condition in self.problem.conditions]
@@ -197,7 +190,6 @@ class PINN(object):
#plt.plot(x.detach(), y.detach(), 'o', label=location)
np.savetxt('burgers_{}_pts.txt'.format(location), self.input_pts[location].tensor.detach(), header='x y', delimiter=' ')
gggg
plt.legend()
plt.show()
@@ -234,7 +226,7 @@ class PINN(object):
if trial:
import optuna
trial.report(loss[0].item()+loss[1].item(), epoch)
trial.report(sum(losses), epoch)
if trial.should_prune():
raise optuna.exceptions.TrialPruned()
@@ -282,6 +274,9 @@ class PINN(object):
print("Something went wrong...")
print("Not able to compute the error. Please pass a data solution or a true solution")
def plot(self, res, filename=None, variable=None):
'''
'''
@@ -289,6 +284,9 @@ class PINN(object):
matplotlib.use('Agg')
import matplotlib.pyplot as plt
self._plot_2D(res, filename, variable)
print('TTTTTTTTTTTTTTTTTt')
print(self.problem.bounds)
pts_container = []
#for mn, mx in [[-1, 1], [-1, 1]]:
for mn, mx in [[0, 1], [0, 1]]: