diff --git a/examples/problems/stokes.py b/examples/problems/stokes.py new file mode 100644 index 0000000..9874f29 --- /dev/null +++ b/examples/problems/stokes.py @@ -0,0 +1,45 @@ +import numpy as np +import torch + +from pina.problem import SpatialProblem +from pina.operators import nabla, grad, div +from pina import Condition, Span, LabelTensor + + +class Stokes(SpatialProblem): + + spatial_variables = ['x', 'y'] + output_variables = ['ux', 'uy', 'p'] + domain = Span({'x': [-2, 2], 'y': [-1, 1]}) + + def momentum(input_, output_): + #print(nabla(output_['ux', 'uy'], input_)) + #print(grad(output_['p'], input_)) + nabla_ = LabelTensor.hstack([ + LabelTensor(nabla(output_['ux'], input_), ['x']), + LabelTensor(nabla(output_['uy'], input_), ['y'])]) + #return LabelTensor(nabla_.tensor + grad(output_['p'], input_).tensor, ['x', 'y']) + return nabla_.tensor + grad(output_['p'], input_).tensor + + def continuity(input_, output_): + return div(output_['ux', 'uy'], input_) + + def inlet(input_, output_): + value = 2.0 + return output_['ux'] - value + + def outlet(input_, output_): + value = 0.0 + return output_['p'] - value + + def wall(input_, output_): + value = 0.0 + return output_['ux', 'uy'].tensor - value + + conditions = { + 'gamma_top': Condition(Span({'x': [-2, 2], 'y': 1}), wall), + 'gamma_bot': Condition(Span({'x': [-2, 2], 'y': -1}), wall), + 'gamma_out': Condition(Span({'x': 2, 'y': [-1, 1]}), outlet), + 'gamma_in': Condition(Span({'x': -2, 'y': [-1, 1]}), inlet), + 'D': Condition(Span({'x': [-2, 2], 'y': [-1, 1]}), [momentum, continuity]), + } diff --git a/examples/run_stokes.py b/examples/run_stokes.py new file mode 100644 index 0000000..e6438cd --- /dev/null +++ b/examples/run_stokes.py @@ -0,0 +1,54 @@ +import argparse +import sys +import numpy as np +import torch +from torch.nn import ReLU, Tanh, Softplus + +from pina import PINN, LabelTensor, Plotter +from pina.model import FeedForward +from pina.adaptive_functions import AdaptiveSin, AdaptiveCos, AdaptiveTanh +from problems.stokes import Stokes + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description="Run PINA") + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument("-s", "-save", action="store_true") + group.add_argument("-l", "-load", action="store_true") + parser.add_argument("id_run", help="number of run", type=int) + args = parser.parse_args() + + + stokes_problem = Stokes() + model = FeedForward( + layers=[40, 20, 20, 10], + output_variables=stokes_problem.output_variables, + input_variables=stokes_problem.input_variables, + func=Softplus, + ) + + pinn = PINN( + stokes_problem, + model, + lr=0.006, + error_norm='mse', + regularizer=1e-8, + lr_accelerate=None) + + if args.s: + + #pinn.span_pts(200, 'grid', ['gamma_out']) + pinn.span_pts(200, 'grid', ['gamma_top', 'gamma_bot', 'gamma_in', 'gamma_out']) + pinn.span_pts(2000, 'random', ['D']) + #plotter = Plotter() + #plotter.plot_samples(pinn) + pinn.train(10000, 100) + pinn.save_state('pina.stokes') + + else: + pinn.load_state('pina.stokes') + plotter = Plotter() + plotter.plot_samples(pinn) + plotter.plot(pinn) + + diff --git a/pina/condition.py b/pina/condition.py index 398c387..27eb42b 100644 --- a/pina/condition.py +++ b/pina/condition.py @@ -15,6 +15,9 @@ class Condition: elif isinstance(args[0], Location) and callable(args[1]): self.location = args[0] self.function = args[1] + elif isinstance(args[0], Location) and isinstance(args[1], list): + self.location = args[0] + self.function = args[1] else: raise ValueError diff --git a/pina/label_tensor.py b/pina/label_tensor.py index 16eec26..c9e219b 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -12,6 +12,9 @@ class LabelTensor(): self.tensor = x def __getitem__(self, key): + if isinstance(key, (tuple, list)): + indeces = [self.labels.index(k) for k in key] + return LabelTensor(self.tensor[:, indeces], [self.labels[idx] for idx in indeces]) if key in self.labels: return self.tensor[:, self.labels.index(key)] else: diff --git a/pina/model/feed_forward.py b/pina/model/feed_forward.py index 6095a97..96e82bb 100644 --- a/pina/model/feed_forward.py +++ b/pina/model/feed_forward.py @@ -52,6 +52,8 @@ class FeedForward(torch.nn.Module): def forward(self, x): """ """ + + x = x[self.input_variables] nf = len(self.extra_features) if nf == 0: return LabelTensor(self.model(x.tensor), self.output_variables) diff --git a/pina/pinn.py b/pina/pinn.py index 916f2be..fc3d233 100644 --- a/pina/pinn.py +++ b/pina/pinn.py @@ -65,8 +65,8 @@ class PINN(object): self.model = model self.model.to(dtype=self.dtype, device=self.device) - self.input_pts = {} self.truth_values = {} + self.input_pts = {} self.trained_epoch = 0 @@ -171,13 +171,15 @@ class PINN(object): except: pts = condition.input_points + print(location, pts) + self.input_pts[location] = pts - print(pts.tensor.shape) self.input_pts[location].tensor.to(dtype=self.dtype, device=self.device) self.input_pts[location].tensor.requires_grad_(True) self.input_pts[location].tensor.retain_grad() + def plot_pts(self, locations='all'): import matplotlib matplotlib.use('GTK3Agg') @@ -209,8 +211,13 @@ class PINN(object): predicted = self.model(pts) - residuals = condition.function(pts, predicted) - losses.append(self._compute_norm(residuals)) + if isinstance(condition.function, list): + for function in condition.function: + residuals = function(pts, predicted) + losses.append(self._compute_norm(residuals)) + else: + residuals = condition.function(pts, predicted) + losses.append(self._compute_norm(residuals)) self.optimizer.zero_grad() sum(losses).backward() diff --git a/pina/plotter.py b/pina/plotter.py index 15add45..c70d956 100644 --- a/pina/plotter.py +++ b/pina/plotter.py @@ -84,11 +84,13 @@ class Plotter: """ res = 256 pts = obj.problem.domain.sample(res, 'grid') + print(pts) grids_container = [ - pts[:, 0].reshape(res, res), - pts[:, 1].reshape(res, res), + pts.tensor[:, 0].reshape(res, res), + pts.tensor[:, 1].reshape(res, res), ] predicted_output = obj.model(pts) + predicted_output = predicted_output['p'] if hasattr(obj.problem, 'truth_solution'): truth_output = obj.problem.truth_solution(*pts.tensor.T).float() @@ -102,10 +104,56 @@ class Plotter: fig.colorbar(cb, ax=axes[2]) else: fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6)) - cb = getattr(axes, method)(*grids_container, predicted_output.tensor.reshape(res, res).detach()) + # cb = getattr(axes, method)(*grids_container, predicted_output.tensor.reshape(res, res).detach()) + cb = getattr(axes, method)(*grids_container, predicted_output.reshape(res, res).detach()) fig.colorbar(cb, ax=axes) if filename: plt.savefig(filename) else: plt.show() + + + def plot(self, obj, method='contourf', filename=None): + """ + """ + res = 256 + pts = obj.problem.domain.sample(res, 'grid') + print(pts) + grids_container = [ + pts.tensor[:, 0].reshape(res, res), + pts.tensor[:, 1].reshape(res, res), + ] + predicted_output = obj.model(pts) + predicted_output = predicted_output['ux'] + + if hasattr(obj.problem, 'truth_solution'): + truth_output = obj.problem.truth_solution(*pts.tensor.T).float() + fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6)) + + cb = getattr(axes[0], method)(*grids_container, predicted_output.tensor.reshape(res, res).detach()) + fig.colorbar(cb, ax=axes[0]) + cb = getattr(axes[1], method)(*grids_container, truth_output.reshape(res, res).detach()) + fig.colorbar(cb, ax=axes[1]) + cb = getattr(axes[2], method)(*grids_container, (truth_output-predicted_output.tensor.float().flatten()).detach().reshape(res, res)) + fig.colorbar(cb, ax=axes[2]) + else: + fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6)) + # cb = getattr(axes, method)(*grids_container, predicted_output.tensor.reshape(res, res).detach()) + cb = getattr(axes, method)(*grids_container, predicted_output.reshape(res, res).detach()) + fig.colorbar(cb, ax=axes) + + if filename: + plt.savefig(filename) + else: + plt.show() + + + + def plot_samples(self, obj): + + for location in obj.input_pts: + plt.plot(*obj.input_pts[location].tensor.T.detach(), '.', label=location) + + plt.legend() + plt.show() diff --git a/pina/span.py b/pina/span.py index c8f33a6..f3c55e5 100644 --- a/pina/span.py +++ b/pina/span.py @@ -37,6 +37,7 @@ class Span(Location): for _ in range(bounds.shape[0])]) grids = np.meshgrid(*pts) pts = np.hstack([grid.reshape(-1, 1) for grid in grids]) + print(pts) elif mode == 'lh' or mode == 'latin': from scipy.stats import qmc sampler = qmc.LatinHypercube(d=bounds.shape[0])