From c52ee03b534ab7c99ab35a340a14028f7d2f44f7 Mon Sep 17 00:00:00 2001 From: Dario Coscia <93731561+dario-coscia@users.noreply.github.com> Date: Wed, 14 Dec 2022 13:56:10 +0100 Subject: [PATCH] solving span bugs (#57) --- pina/pinn.py | 17 ++++++++++------- pina/span.py | 19 +++++++++++++++++++ tests/test_pinn.py | 28 +++++++++++++++++++++++++++- 3 files changed, 56 insertions(+), 8 deletions(-) diff --git a/pina/pinn.py b/pina/pinn.py index e990da4..afd285a 100644 --- a/pina/pinn.py +++ b/pina/pinn.py @@ -173,7 +173,16 @@ class PINN(object): >>> pinn.span_pts(n=10, mode='grid', location=['bound1']) >>> pinn.span_pts(n=10, mode='grid', variables=['x']) """ - if isinstance(args[0], int) and isinstance(args[1], str): + + if all(key in kwargs for key in ['n', 'mode']): + argument = {} + argument['n'] = kwargs['n'] + argument['mode'] = kwargs['mode'] + argument['variables'] = self.problem.input_variables + arguments = [argument] + elif any(key in kwargs for key in ['n', 'mode']) and args: + raise ValueError("Don't mix args and kwargs") + elif isinstance(args[0], int) and isinstance(args[1], str): argument = {} argument['n'] = int(args[0]) argument['mode'] = args[1] @@ -181,12 +190,6 @@ class PINN(object): arguments = [argument] elif all(isinstance(arg, dict) for arg in args): arguments = args - elif all(key in kwargs for key in ['n', 'mode']): - argument = {} - argument['n'] = kwargs['n'] - argument['mode'] = kwargs['mode'] - argument['variables'] = self.problem.input_variables - arguments = [argument] else: raise RuntimeError diff --git a/pina/span.py b/pina/span.py index 141d5a4..20f5612 100644 --- a/pina/span.py +++ b/pina/span.py @@ -100,6 +100,25 @@ class Span(Location): result = result.append(pts_variable, mode='std') return result + def _single_points_sample(n, variables): + tmp = [] + for variable in variables: + if variable in self.fixed_.keys(): + value = self.fixed_[variable] + pts_variable = torch.tensor([[value]]).repeat(n, 1) + pts_variable = pts_variable.as_subclass(LabelTensor) + pts_variable.labels = [variable] + tmp.append(pts_variable) + + result = tmp[0] + for i in tmp[1:]: + result = result.append(i, mode='std') + + return result + + if self.fixed_ and (not self.range_): + return _single_points_sample(n, variables) + if variables == 'all': variables = list(self.range_.keys()) + list(self.fixed_.keys()) diff --git a/tests/test_pinn.py b/tests/test_pinn.py index 439a146..7c0cd0b 100644 --- a/tests/test_pinn.py +++ b/tests/test_pinn.py @@ -70,6 +70,32 @@ def test_span_pts(): assert pinn.input_pts['D'].shape[0] == n +def test_sampling_all_args(): + pinn = PINN(problem, model) + n = 10 + pinn.span_pts(n, 'grid', locations=['D']) + + +def test_sampling_all_kwargs(): + pinn = PINN(problem, model) + n = 10 + pinn.span_pts(n=n, mode='latin', locations=['D']) + + +def test_sampling_dict(): + pinn = PINN(problem, model) + n = 10 + pinn.span_pts( + {'variables': ['x', 'y'], 'mode': 'grid', 'n': n}, locations=['D']) + + +def test_sampling_mixed_args_kwargs(): + pinn = PINN(problem, model) + n = 10 + with pytest.raises(ValueError): + pinn.span_pts(n, mode='latin', locations=['D']) + + def test_train(): pinn = PINN(problem, model) boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] @@ -130,4 +156,4 @@ if torch.cuda.is_available(): n = 100 pinn.span_pts(n, 'grid', boundaries) pinn.span_pts(n, 'grid', locations=['D']) - pinn.train(5) \ No newline at end of file + pinn.train(5)