From 736c78fd64f0e11defb3eecbca9096572a14eaf9 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 18 Apr 2023 10:49:57 +0200 Subject: [PATCH] add docs --- docs/Makefile | 2 +- docs/source/_rst/code.rst | 4 +- docs/source/_rst/condition.rst | 12 +++ docs/source/_rst/deeponet.rst | 8 +- docs/source/_rst/fnn.rst | 8 +- docs/source/_rst/label_tensor.rst | 8 +- docs/source/_rst/pinn.rst | 8 +- docs/source/conf.py | 18 ++--- docs/source/index.rst | 4 +- examples/run_poisson_deeponet.py | 124 +++++++++++++++--------------- pina/condition.py | 113 ++++++++++++++++++--------- pina/location.py | 10 ++- pina/model/deeponet.py | 26 +++++-- pina/model/feed_forward.py | 4 +- pina/model/multi_feed_forward.py | 5 +- pina/pinn.py | 68 +++++++++------- tests/test_condition.py | 42 ++++++++++ 17 files changed, 292 insertions(+), 172 deletions(-) create mode 100644 docs/source/_rst/condition.rst create mode 100644 tests/test_condition.py diff --git a/docs/Makefile b/docs/Makefile index ed2201d..7985d8c 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -3,7 +3,7 @@ # You can set these variables from the command line. SPHINXOPTS = -SPHINXBUILD = sphinx-build +SPHINXBUILD = python -msphinx #sphinx-build PAPER = BUILDDIR = build diff --git a/docs/source/_rst/code.rst b/docs/source/_rst/code.rst index a74dc26..feb61db 100644 --- a/docs/source/_rst/code.rst +++ b/docs/source/_rst/code.rst @@ -15,4 +15,6 @@ Code Documentation SpatialProblem TimeDependentProblem Operators - Plotter \ No newline at end of file + Plotter + PINN + Condition diff --git a/docs/source/_rst/condition.rst b/docs/source/_rst/condition.rst new file mode 100644 index 0000000..fe834b7 --- /dev/null +++ b/docs/source/_rst/condition.rst @@ -0,0 +1,12 @@ +Condition +========= +.. currentmodule:: pina.condition + +.. automodule:: pina.condition + +.. autoclass:: Condition + :members: + :private-members: + :undoc-members: + :show-inheritance: + :noindex: diff --git a/docs/source/_rst/deeponet.rst b/docs/source/_rst/deeponet.rst index 34f3661..94bd6e9 100644 --- a/docs/source/_rst/deeponet.rst +++ b/docs/source/_rst/deeponet.rst @@ -5,8 +5,6 @@ DeepONet .. automodule:: pina.model.deeponet .. autoclass:: DeepONet - :members: - :private-members: - :undoc-members: - :show-inheritance: - :noindex: + :members: + :private-members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/fnn.rst b/docs/source/_rst/fnn.rst index 3ddfaf1..98bec0c 100644 --- a/docs/source/_rst/fnn.rst +++ b/docs/source/_rst/fnn.rst @@ -5,8 +5,6 @@ FeedForward .. automodule:: pina.model.feed_forward .. autoclass:: FeedForward - :members: - :private-members: - :undoc-members: - :show-inheritance: - :noindex: + :members: + :private-members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/label_tensor.rst b/docs/source/_rst/label_tensor.rst index 976e839..07a665e 100644 --- a/docs/source/_rst/label_tensor.rst +++ b/docs/source/_rst/label_tensor.rst @@ -5,8 +5,6 @@ LabelTensor .. automodule:: pina.label_tensor .. autoclass:: LabelTensor - :members: - :private-members: - :undoc-members: - :show-inheritance: - :noindex: + :members: + :private-members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/pinn.rst b/docs/source/_rst/pinn.rst index 4cf03af..ff85342 100644 --- a/docs/source/_rst/pinn.rst +++ b/docs/source/_rst/pinn.rst @@ -5,8 +5,6 @@ PINN .. automodule:: pina.pinn .. autoclass:: PINN - :members: - :private-members: - :undoc-members: - :show-inheritance: - :noindex: + :members: + :private-members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index 766b7d6..df7a1a8 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -36,28 +36,28 @@ import pina # ones. extensions = [ 'sphinx.ext.autodoc', - 'sphinx.ext.autosummary', - 'sphinx.ext.coverage', - 'sphinx.ext.graphviz', - 'sphinx.ext.doctest', + #'sphinx.ext.autosummary', + #'sphinx.ext.coverage', + #'sphinx.ext.graphviz', + #'sphinx.ext.doctest', 'sphinx.ext.intersphinx', 'sphinx.ext.todo', - 'sphinx.ext.coverage', + #'sphinx.ext.coverage', 'sphinx.ext.viewcode', - 'sphinx.ext.ifconfig', + #'sphinx.ext.ifconfig', 'sphinx.ext.mathjax', ] -autosummary_generate = True +#autosummary_generate = True intersphinx_mapping = { - 'python': ('http://docs.python.org/2', None), + 'python': ('http://docs.python.org/3', None), 'numpy': ('http://docs.scipy.org/doc/numpy/', None), 'scipy': ('http://docs.scipy.org/doc/scipy/reference/', None), 'matplotlib': ('http://matplotlib.sourceforge.net/', None), 'torch': ('https://pytorch.org/docs/stable/', None), - 'pina': ('https://mathlab.github.io/PINA/', None) } +nitpicky = True # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] diff --git a/docs/source/index.rst b/docs/source/index.rst index 52fa72b..f911de7 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -14,11 +14,11 @@ learning tasks while respecting any given law of physics described by general nonlinear differential equations. Proposed in "Physics-informed neural networks: A deep learning framework for solving forward and inverse problems involving nonlinear partial differential equations", such framework aims to -solve problems in a continuous and nonlinear settings. +solve problems in a continuous and nonlinear settings. :py:class:`pina.pinn.PINN` .. toctree:: - :maxdepth: 1 + :maxdepth: 2 :caption: Package Documentation: Installation <_rst/installation> diff --git a/examples/run_poisson_deeponet.py b/examples/run_poisson_deeponet.py index 4f5e69d..d1a9891 100644 --- a/examples/run_poisson_deeponet.py +++ b/examples/run_poisson_deeponet.py @@ -5,11 +5,7 @@ import torch from problems.poisson import Poisson from pina import PINN, LabelTensor, Plotter -from pina.model.deeponet import DeepONet, check_combos, spawn_combo_networks - -logging.basicConfig( - filename="poisson_deeponet.log", filemode="w", level=logging.INFO -) +from pina.model import DeepONet, FeedForward class SinFeature(torch.nn.Module): @@ -36,27 +32,33 @@ class SinFeature(torch.nn.Module): return LabelTensor(t, [f"sin({self._label})"]) -def prepare_deeponet_model(args, problem, extra_feature_combo_func=None): - combos = tuple(map(lambda combo: combo.split("-"), args.combos.split(","))) - check_combos(combos, problem.input_variables) +class myRBF(torch.nn.Module): + def __init__(self, input_): - extra_feature = extra_feature_combo_func if args.extra else None - networks = spawn_combo_networks( - combos=combos, - layers=list(map(int, args.layers.split(","))) if args.layers else [], - output_dimension=args.hidden * len(problem.output_variables), - func=torch.nn.Softplus, - extra_feature=extra_feature, - bias=not args.nobias, - ) + super().__init__() - return DeepONet( - networks, - problem.output_variables, - aggregator=args.aggregator, - reduction=args.reduction, - ) + self.input_variables = [input_] + self.a = torch.nn.Parameter(torch.tensor([-.3])) + # self.b = torch.nn.Parameter(torch.tensor([0.5])) + self.b = torch.tensor([0.5]) + self.c = torch.nn.Parameter(torch.tensor([.5])) + def forward(self, x): + x = x.extract(self.input_variables) + result = self.a * torch.exp(-(x - self.b)**2/(self.c**2)) + return result + +class myModel(torch.nn.Module): + def __init__(self): + + super().__init__() + self.ffn_x = myRBF('x') + self.ffn_y = myRBF('y') + + def forward(self, x): + result = self.ffn_x(x) * self.ffn_y(x) + result.labels = ['u'] + return result if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run PINA") @@ -65,49 +67,49 @@ if __name__ == "__main__": parser.add_argument("id_run", help="Run ID", type=int) parser.add_argument("--extra", help="Extra features", action="store_true") - parser.add_argument("--nobias", action="store_true") - parser.add_argument( - "--combos", - help="DeepONet internal network combinations", - type=str, - required=True, - ) - parser.add_argument( - "--aggregator", help="Aggregator for DeepONet", type=str, default="*" - ) - parser.add_argument( - "--reduction", help="Reduction for DeepONet", type=str, default="+" - ) - parser.add_argument( - "--hidden", - help="Number of variables in the hidden DeepONet layer", - type=int, - required=True, - ) - parser.add_argument( - "--layers", - help="Structure of the DeepONet partial layers", - type=str, - required=True, - ) - cli_args = parser.parse_args() + args = parser.parse_args() - poisson_problem = Poisson() + problem = Poisson() - model = prepare_deeponet_model( - cli_args, - poisson_problem, - extra_feature_combo_func=lambda combo: [SinFeature(combo)], - ) - pinn = PINN(poisson_problem, model, lr=0.01, regularizer=1e-8) - if cli_args.save: + # ffn_x = FeedForward( + # input_variables=['x'], layers=[], output_variables=1, + # func=torch.nn.Softplus, + # extra_features=[SinFeature('x')] + # ) + # ffn_y = FeedForward + # input_variables=['y'], layers=[], output_variables=1, + # func=torch.nn.Softplus, + # extra_features=[SinFeature('y')] + # ) + model = myModel() + test = torch.tensor([[0.0, 0.5]]) + test.labels = ['x', 'y'] + pinn = PINN(problem, model, lr=0.0001) + + if args.save: pinn.span_pts( 20, "grid", locations=["gamma1", "gamma2", "gamma3", "gamma4"] ) pinn.span_pts(20, "grid", locations=["D"]) - pinn.train(1.0e-10, 100) - pinn.save_state(f"pina.poisson_{cli_args.id_run}") - if cli_args.load: - pinn.load_state(f"pina.poisson_{cli_args.id_run}") + while True: + pinn.train(500, 50) + print(model.ffn_x.a) + print(model.ffn_x.b) + print(model.ffn_x.c) + + xi = torch.linspace(0, 1, 64).reshape(-1, 1).as_subclass(LabelTensor) + xi.labels = ['x'] + yi = model.ffn_x(xi) + y_truth = -torch.sin(xi*torch.pi) + + import matplotlib.pyplot as plt + plt.plot(xi.detach().flatten(), yi.detach().flatten(), 'r-') + plt.plot(xi.detach().flatten(), y_truth.detach().flatten(), 'b-') + plt.plot(xi.detach().flatten(), -y_truth.detach().flatten(), 'b-') + plt.show() + pinn.save_state(f"pina.poisson_{args.id_run}") + + if args.load: + pinn.load_state(f"pina.poisson_{args.id_run}") plotter = Plotter() plotter.plot(pinn) diff --git a/pina/condition.py b/pina/condition.py index 09c4399..e9f0679 100644 --- a/pina/condition.py +++ b/pina/condition.py @@ -1,43 +1,86 @@ -""" """ -import torch - +""" Condition module. """ +from .label_tensor import LabelTensor from .location import Location +def dummy(a): + """Dummy function for testing purposes.""" + return None + class Condition: + """ + The class `Condition` is used to represent the constraints (physical + equations, boundary conditions, etc.) that should be satisfied in the + problem at hand. Condition objects are used to formulate the PINA :obj:`pina.problem.abstract_problem.Abstract_Problem` object. + Conditions can be specified in three ways: + + 1. By specifying the input and output points of the condition; in such a + case, the model is trained to produce the output points given the input + points. + + 2. By specifying the location and the function of the condition; in such + a case, the model is trained to minimize that function by evaluating it + at some samples of the location. + + 3. By specifying the input points and the function of the condition; in + such a case, the model is trained to minimize that function by + evaluating it at the input points. + + Example:: + + >>> example_domain = Span({'x': [0, 1], 'y': [0, 1]}) + >>> def example_dirichlet(input_, output_): + >>> value = 0.0 + >>> return output_.extract(['u']) - value + >>> example_input_pts = LabelTensor( + >>> torch.tensor([[0, 0, 0]]), ['x', 'y', 'z']) + >>> example_output_pts = LabelTensor(torch.tensor([[1, 2]]), ['a', 'b']) + >>> + >>> Condition( + >>> input_points=example_input_pts, + >>> output_points=example_output_pts) + >>> Condition( + >>> location=example_domain, + >>> function=example_dirichlet) + >>> Condition( + >>> input_points=example_input_pts, + >>> function=example_dirichlet) + + """ + + __slots__ = [ + 'input_points', 'output_points', 'location', 'function', + 'data_weight' + ] + + def _dictvalue_isinstance(self, dict_, key_, class_): + """Check if the value of a dictionary corresponding to `key` is an instance of `class_`.""" + if key_ not in dict_.keys(): + return True + + return isinstance(dict_[key_], class_) + def __init__(self, *args, **kwargs): + """ + Constructor for the `Condition` class. + """ + self.data_weight = kwargs.pop('data_weight', 1.0) - if 'data_weight' in kwargs: - self.data_weight = kwargs['data_weight'] - if not 'data_weight' in kwargs: - self.data_weight = 1. + if len(args) != 0: + raise ValueError('Condition takes only the following keyword arguments: {`input_points`, `output_points`, `location`, `function`, `data_weight`}.') - if len(args) == 2: + if ( + sorted(kwargs.keys()) != sorted(['input_points', 'output_points']) and + sorted(kwargs.keys()) != sorted(['location', 'function']) and + sorted(kwargs.keys()) != sorted(['input_points', 'function']) + ): + raise ValueError(f'Invalid keyword arguments {kwargs.keys()}.') - if (isinstance(args[0], torch.Tensor) and - isinstance(args[1], torch.Tensor)): - self.input_points = args[0] - self.output_points = args[1] - elif isinstance(args[0], Location) and callable(args[1]): - self.location = args[0] - self.function = args[1] - elif isinstance(args[0], Location) and isinstance(args[1], list): - self.location = args[0] - self.function = args[1] - else: - raise ValueError + if not self._dictvalue_isinstance(kwargs, 'input_points', LabelTensor): + raise TypeError('`input_points` must be a torch.Tensor.') + if not self._dictvalue_isinstance(kwargs, 'output_points', LabelTensor): + raise TypeError('`output_points` must be a torch.Tensor.') + if not self._dictvalue_isinstance(kwargs, 'location', Location): + raise TypeError('`location` must be a Location.') - elif not args and len(kwargs) >= 2: - - if 'input_points' in kwargs and 'output_points' in kwargs: - self.input_points = kwargs['input_points'] - self.output_points = kwargs['output_points'] - elif 'location' in kwargs and 'function' in kwargs: - self.location = kwargs['location'] - self.function = kwargs['function'] - else: - raise ValueError - else: - raise ValueError - - if hasattr(self, 'function') and not isinstance(self.function, list): - self.function = [self.function] + for key, value in kwargs.items(): + setattr(self, key, value) \ No newline at end of file diff --git a/pina/location.py b/pina/location.py index cbd42e4..7d69504 100644 --- a/pina/location.py +++ b/pina/location.py @@ -5,10 +5,14 @@ from abc import ABCMeta, abstractmethod class Location(metaclass=ABCMeta): """ - Abstract class + Abstract Location class. + Any geometry entity should inherit from this class. """ - @property @abstractmethod def sample(self): - pass + """ + Abstract method for sampling a point from the location. To be + implemented in the child class. + """ + pass \ No newline at end of file diff --git a/pina/model/deeponet.py b/pina/model/deeponet.py index f956248..3bedd82 100644 --- a/pina/model/deeponet.py +++ b/pina/model/deeponet.py @@ -12,7 +12,7 @@ from pina.utils import is_function def check_combos(combos, variables): """ - Check that the given combinations are subsets (overlapping + Check that the given combinations are subsets (overlapping is allowed) of the given set of variables. :param iterable(iterable(str)) combos: Combinations of variables. @@ -35,7 +35,7 @@ def spawn_combo_networks( :param iterable(iterable(str)) combos: Combinations of variables. :param iterable(int) layers: Size of hidden layers. :param int output_dimension: Size of the output layer of the networks. - :param func: Nonlinearity. + :param func: Nonlinearity. :param extra_feature: Extra feature to be considered by the networks. :param bool bias: Whether to consider bias or not. """ @@ -78,15 +78,16 @@ class DeepONet(torch.nn.Module): :param list(str) output_variables: the list containing the labels corresponding to the components of the output computed by the model. - :param string | callable aggregator: Aggregator to be used to aggregate + :param str | callable aggregator: Aggregator to be used to aggregate partial results from the modules in `nets`. Partial results are - aggregated component-wise. See :func:`_symbol_functions` for the + aggregated component-wise. See + :func:`pina.model.deeponet.DeepONet._symbol_functions` for the available default aggregators. - :param string | callable reduction: Reduction to be used to reduce + :param str | callable reduction: Reduction to be used to reduce the aggregated result of the modules in `nets` to the desired output - dimension. See :func:`_symbol_functions` for the available default - reductions. - + dimension. See :py:obj:`pina.model.deeponet.DeepONet._symbol_functions` for the available default + reductions. + :Example: >>> branch = FFN(input_variables=['a', 'c'], output_variables=20) >>> trunk = FFN(input_variables=['b'], output_variables=20) @@ -127,9 +128,15 @@ class DeepONet(torch.nn.Module): raise ValueError("All networks should have the same output size") self._nets = torch.nn.ModuleList(nets) logging.info("Combo DeepONet children: %s", list(self.children())) + self.scale = torch.nn.Parameter(torch.tensor([1.0])) + self.trasl = torch.nn.Parameter(torch.tensor([0.0])) @staticmethod def _symbol_functions(**kwargs): + """ + Return a dictionary of functions that can be used as aggregators or + reductions. + """ return { "+": partial(torch.sum, **kwargs), "*": partial(torch.prod, **kwargs), @@ -215,4 +222,7 @@ class DeepONet(torch.nn.Module): output_ = output_.as_subclass(LabelTensor) output_.labels = self.output_variables + + output_ *= self.scale + output_ += self.trasl return output_ diff --git a/pina/model/feed_forward.py b/pina/model/feed_forward.py index 74a8e01..04299cc 100644 --- a/pina/model/feed_forward.py +++ b/pina/model/feed_forward.py @@ -89,8 +89,8 @@ class FeedForward(torch.nn.Module): """ Defines the computation performed at every call. - :param x: the input tensor. - :type x: :class:`pina.LabelTensor` + :param x: . + :type x: :class:`pina.LabelTensor` :return: the output computed by the model. :rtype: LabelTensor """ diff --git a/pina/model/multi_feed_forward.py b/pina/model/multi_feed_forward.py index 984647e..d8d8faa 100644 --- a/pina/model/multi_feed_forward.py +++ b/pina/model/multi_feed_forward.py @@ -9,8 +9,9 @@ class MultiFeedForward(torch.nn.Module): :param dict dff_dict: dictionary of FeedForward networks. """ def __init__(self, dff_dict): - """ - """ + ''' + dff_dict: dict of FeedForward objects + ''' super().__init__() if not isinstance(dff_dict, dict): diff --git a/pina/pinn.py b/pina/pinn.py index 145cd26..f94cab8 100644 --- a/pina/pinn.py +++ b/pina/pinn.py @@ -88,10 +88,13 @@ class PINN(object): @property def problem(self): + """ The problem formulation.""" return self._problem @problem.setter def problem(self, problem): + """ + Set the problem formulation.""" if not isinstance(problem, AbstractProblem): raise TypeError self._problem = problem @@ -99,11 +102,11 @@ class PINN(object): def _compute_norm(self, vec): """ Compute the norm of the `vec` one-dimensional tensor based on the - `self.error_norm` attribute. + `self.error_norm` attribute. .. todo: complete - :param vec torch.tensor: the tensor + :param torch.Tensor vec: the tensor """ if isinstance(self.error_norm, int): return torch.linalg.vector_norm(vec, ord=self.error_norm, dtype=self.dytpe) @@ -115,7 +118,11 @@ class PINN(object): raise RuntimeError def save_state(self, filename): + """ + Save the state of the model. + :param str filename: the filename to save the state to. + """ checkpoint = { 'epoch': self.trained_epoch, 'model_state': self.model.state_dict(), @@ -133,6 +140,11 @@ class PINN(object): torch.save(checkpoint, filename) def load_state(self, filename): + """ + Load the state of the model. + + :param str filename: the filename to load the state from. + """ checkpoint = torch.load(filename) self.model.load_state_dict(checkpoint['model_state']) @@ -298,32 +310,32 @@ class PINN(object): return sum(losses).item() - def error(self, dtype='l2', res=100): + # def error(self, dtype='l2', res=100): - import numpy as np - if hasattr(self.problem, 'truth_solution') and self.problem.truth_solution is not None: - pts_container = [] - for mn, mx in self.problem.domain_bound: - pts_container.append(np.linspace(mn, mx, res)) - grids_container = np.meshgrid(*pts_container) - Z_true = self.problem.truth_solution(*grids_container) + # import numpy as np + # if hasattr(self.problem, 'truth_solution') and self.problem.truth_solution is not None: + # pts_container = [] + # for mn, mx in self.problem.domain_bound: + # pts_container.append(np.linspace(mn, mx, res)) + # grids_container = np.meshgrid(*pts_container) + # Z_true = self.problem.truth_solution(*grids_container) - elif hasattr(self.problem, 'data_solution') and self.problem.data_solution is not None: - grids_container = self.problem.data_solution['grid'] - Z_true = self.problem.data_solution['grid_solution'] - try: - unrolled_pts = torch.tensor([t.flatten() for t in grids_container]).T.to( - dtype=self.dtype, device=self.device) - Z_pred = self.model(unrolled_pts) - Z_pred = Z_pred.detach().numpy().reshape(grids_container[0].shape) + # elif hasattr(self.problem, 'data_solution') and self.problem.data_solution is not None: + # grids_container = self.problem.data_solution['grid'] + # Z_true = self.problem.data_solution['grid_solution'] + # try: + # unrolled_pts = torch.tensor([t.flatten() for t in grids_container]).T.to( + # dtype=self.dtype, device=self.device) + # Z_pred = self.model(unrolled_pts) + # Z_pred = Z_pred.detach().numpy().reshape(grids_container[0].shape) - if dtype == 'l2': - return np.linalg.norm(Z_pred - Z_true)/np.linalg.norm(Z_true) - else: - # TODO H1 - pass - except: - print("") - print("Something went wrong...") - print( - "Not able to compute the error. Please pass a data solution or a true solution") + # if dtype == 'l2': + # return np.linalg.norm(Z_pred - Z_true)/np.linalg.norm(Z_true) + # else: + # # TODO H1 + # pass + # except: + # print("") + # print("Something went wrong...") + # print( + # "Not able to compute the error. Please pass a data solution or a true solution") diff --git a/tests/test_condition.py b/tests/test_condition.py new file mode 100644 index 0000000..56ae665 --- /dev/null +++ b/tests/test_condition.py @@ -0,0 +1,42 @@ +import torch +import pytest + +from pina import LabelTensor, Condition, Span, PINN +from pina.problem import SpatialProblem +from pina.model import FeedForward +from pina.operators import nabla + + +example_domain = Span({'x': [0, 1], 'y': [0, 1]}) +def example_dirichlet(input_, output_): + value = 0.0 + return output_.extract(['u']) - value +example_input_pts = LabelTensor(torch.tensor([[0, 0, 0]]), ['x', 'y', 'z']) +example_output_pts = LabelTensor(torch.tensor([[1, 2]]), ['a', 'b']) + +def test_init_inputoutput(): + Condition(input_points=example_input_pts, output_points=example_output_pts) + with pytest.raises(ValueError): + Condition(example_input_pts, example_output_pts) + with pytest.raises(TypeError): + Condition(input_points=3., output_points='example') + with pytest.raises(TypeError): + Condition(input_points=example_domain, output_points=example_dirichlet) + +def test_init_locfunc(): + Condition(location=example_domain, function=example_dirichlet) + with pytest.raises(ValueError): + Condition(example_domain, example_dirichlet) + with pytest.raises(TypeError): + Condition(location=3., function='example') + with pytest.raises(TypeError): + Condition(location=example_input_pts, function=example_output_pts) + +def test_init_inputfunc(): + Condition(input_points=example_input_pts, function=example_dirichlet) + with pytest.raises(ValueError): + Condition(example_domain, example_dirichlet) + with pytest.raises(TypeError): + Condition(input_points=3., function='example') + with pytest.raises(TypeError): + Condition(input_points=example_domain, funtion=example_output_pts) \ No newline at end of file