fix old codes
This commit is contained in:
@@ -7,7 +7,7 @@ from torch.nn import ReLU, Tanh, Softplus
|
||||
from pina import PINN, LabelTensor, Plotter
|
||||
from pina.model import FeedForward
|
||||
from pina.adaptive_functions import AdaptiveSin, AdaptiveCos, AdaptiveTanh
|
||||
from problems.poisson import Poisson
|
||||
from poisson2 import Poisson
|
||||
|
||||
|
||||
class myFeature(torch.nn.Module):
|
||||
@@ -19,7 +19,9 @@ class myFeature(torch.nn.Module):
|
||||
super(myFeature, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return torch.sin(x[:, 0]*torch.pi) * torch.sin(x[:, 1]*torch.pi)
|
||||
t = (torch.sin(x.extract(['x'])*torch.pi) *
|
||||
torch.sin(x.extract(['y'])*torch.pi))
|
||||
return LabelTensor(t, ['sin(x)sin(y)'])
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -51,14 +53,9 @@ if __name__ == "__main__":
|
||||
|
||||
if args.s:
|
||||
|
||||
print(pinn)
|
||||
pinn.span_pts(20, mode_spatial='grid', locations=['gamma1', 'gamma2', 'gamma3', 'gamma4'])
|
||||
pinn.span_pts(20, mode_spatial='grid', locations=['D'])
|
||||
pinn.plot_pts()
|
||||
pinn.span_pts(20, 'grid', locations=['gamma1', 'gamma2', 'gamma3', 'gamma4'])
|
||||
pinn.span_pts(20, 'grid', locations=['D'])
|
||||
pinn.train(5000, 100)
|
||||
with open('poisson_history_{}_{}.txt'.format(args.id_run, args.features), 'w') as file_:
|
||||
for i, losses in enumerate(pinn.history):
|
||||
file_.write('{} {}\n'.format(i, sum(losses)))
|
||||
pinn.save_state('pina.poisson')
|
||||
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user