diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst index 68ee796..9062b3b 100644 --- a/docs/source/_rst/_code.rst +++ b/docs/source/_rst/_code.rst @@ -41,6 +41,7 @@ Solvers CausalPINN CompetitivePINN SAPINN + RBAPINN Supervised solver ReducedOrderModelSolver GAROM diff --git a/docs/source/_rst/solvers/rba_pinn.rst b/docs/source/_rst/solvers/rba_pinn.rst new file mode 100644 index 0000000..b964cce --- /dev/null +++ b/docs/source/_rst/solvers/rba_pinn.rst @@ -0,0 +1,7 @@ +RBAPINN +======== +.. currentmodule:: pina.solvers.pinns.rbapinn + +.. autoclass:: RBAPINN + :members: + :show-inheritance: \ No newline at end of file diff --git a/pina/solvers/__init__.py b/pina/solvers/__init__.py index 6b75566..7bb988d 100644 --- a/pina/solvers/__init__.py +++ b/pina/solvers/__init__.py @@ -6,6 +6,7 @@ __all__ = [ "CausalPINN", "CompetitivePINN", "SAPINN", + "RBAPINN", "SupervisedSolver", "ReducedOrderModelSolver", "GAROM", diff --git a/pina/solvers/pinns/__init__.py b/pina/solvers/pinns/__init__.py index c8aa904..8c77966 100644 --- a/pina/solvers/pinns/__init__.py +++ b/pina/solvers/pinns/__init__.py @@ -5,6 +5,7 @@ __all__ = [ "CausalPINN", "CompetitivePINN", "SAPINN", + "RBAPINN", ] from .basepinn import PINNInterface @@ -13,3 +14,4 @@ from .gpinn import GPINN from .causalpinn import CausalPINN from .competitive_pinn import CompetitivePINN from .sapinn import SAPINN +from .rbapinn import RBAPINN diff --git a/pina/solvers/pinns/rbapinn.py b/pina/solvers/pinns/rbapinn.py new file mode 100644 index 0000000..770b0a7 --- /dev/null +++ b/pina/solvers/pinns/rbapinn.py @@ -0,0 +1,170 @@ +""" Module for RBAPINN. """ + +from copy import deepcopy +import torch +from torch.optim.lr_scheduler import ConstantLR +from .pinn import PINN +from ...utils import check_consistency + + +class RBAPINN(PINN): + r""" + Residual-based Attention PINN (RBAPINN) solver class. + This class implements Residual-based Attention Physics Informed Neural + Network solvers, using a user specified ``model`` to solve a specific + ``problem``. It can be used for solving both forward and inverse problems. + + The Residual-based Attention Physics Informed Neural Network aims to find + the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m` + of the differential problem: + + .. math:: + + \begin{cases} + \mathcal{A}[\mathbf{u}](\mathbf{x})=0\quad,\mathbf{x}\in\Omega\\ + \mathcal{B}[\mathbf{u}](\mathbf{x})=0\quad, + \mathbf{x}\in\partial\Omega + \end{cases} + + minimizing the loss function + + .. math:: + + \mathcal{L}_{\rm{problem}} = \frac{1}{N} \sum_{i=1}^{N_\Omega} + \lambda_{\Omega}^{i} \mathcal{L} \left( \mathcal{A} + [\mathbf{u}](\mathbf{x}) \right) + \frac{1}{N} + \sum_{i=1}^{N_{\partial\Omega}} + \lambda_{\partial\Omega}^{i} \mathcal{L} + \left( \mathcal{B}[\mathbf{u}](\mathbf{x}) + \right), + + denoting the weights as + :math:`\lambda_{\Omega}^1, \dots, \lambda_{\Omega}^{N_\Omega}` and + :math:`\lambda_{\partial \Omega}^1, \dots, + \lambda_{\Omega}^{N_\partial \Omega}` + for :math:`\Omega` and :math:`\partial \Omega`, respectively. + + Residual-based Attention Physics Informed Neural Network computes + the weights by updating them at every epoch as follows + + .. math:: + + \lambda_i^{k+1} \leftarrow \gamma\lambda_i^{k} + + \eta\frac{\lvert r_i\rvert}{\max_j \lvert r_j\rvert}, + + where :math:`r_i` denotes the residual at point :math:`i`, + :math:`\gamma` denotes the decay rate, and :math:`\eta` is + the learning rate for the weights' update. + + .. seealso:: + **Original reference**: Sokratis J. Anagnostopoulos, Juan D. Toscano, + Nikolaos Stergiopulos, and George E. Karniadakis. + "Residual-based attention and connection to information + bottleneck theory in PINNs". + Computer Methods in Applied Mechanics and Engineering 421 (2024): 116805 + DOI: `10.1016/ + j.cma.2024.116805 `_. + """ + + def __init__( + self, + problem, + model, + extra_features=None, + loss=torch.nn.MSELoss(), + optimizer=torch.optim.Adam, + optimizer_kwargs={"lr": 0.001}, + scheduler=ConstantLR, + scheduler_kwargs={"factor": 1, "total_iters": 0}, + eta=0.001, + gamma=0.999, + ): + """ + :param AbstractProblem problem: The formulation of the problem. + :param torch.nn.Module model: The neural network model to use. + :param torch.nn.Module extra_features: The additional input + features to use as augmented input. + :param torch.nn.Module loss: The loss function used as minimizer, + default :class:`torch.nn.MSELoss`. + :param torch.optim.Optimizer optimizer: The neural network optimizer to + use; default is :class:`torch.optim.Adam`. + :param dict optimizer_kwargs: Optimizer constructor keyword args. + :param torch.optim.LRScheduler scheduler: Learning + rate scheduler. + :param dict scheduler_kwargs: LR scheduler constructor keyword args. + :param float | int eta: The learning rate for the + weights of the residual. + :param float gamma: The decay parameter in the update of the weights + of the residual. + """ + super().__init__( + problem=problem, + model=model, + extra_features=extra_features, + loss=loss, + optimizer=optimizer, + optimizer_kwargs=optimizer_kwargs, + scheduler=scheduler, + scheduler_kwargs=scheduler_kwargs, + ) + + # check consistency + check_consistency(eta, (float, int)) + check_consistency(gamma, float) + self.eta = eta + self.gamma = gamma + + # initialize weights + self.weights = {} + for condition_name in problem.conditions: + self.weights[condition_name] = 0 + + # define vectorial loss + self._vectorial_loss = deepcopy(loss) + self._vectorial_loss.reduction = "none" + + def _vect_to_scalar(self, loss_value): + """ + Elaboration of the pointwise loss. + + :param LabelTensor loss_value: the matrix of pointwise loss. + + :return: the scalar loss. + :rtype LabelTensor + """ + if self.loss.reduction == "mean": + ret = torch.mean(loss_value) + elif self.loss.reduction == "sum": + ret = torch.sum(loss_value) + else: + raise RuntimeError( + f"Invalid reduction, got {self.loss.reduction} " + "but expected mean or sum." + ) + return ret + + def loss_phys(self, samples, equation): + """ + Computes the physics loss for the residual-based attention PINN + solver based on given samples and equation. + + :param LabelTensor samples: The samples to evaluate the physics loss. + :param EquationInterface equation: The governing equation + representing the physics. + :return: The physics loss calculated based on given + samples and equation. + :rtype: LabelTensor + """ + residual = self.compute_residual(samples=samples, equation=equation) + cond = self.current_condition_name + + r_norm = self.eta * torch.abs(residual) / torch.max(torch.abs(residual)) + self.weights[cond] = (self.gamma * self.weights[cond] + r_norm).detach() + + loss_value = self._vectorial_loss( + torch.zeros_like(residual, requires_grad=True), residual + ) + + self.store_log(loss_value=float(self._vect_to_scalar(loss_value))) + + return self._vect_to_scalar(self.weights[cond] ** 2 * loss_value) diff --git a/setup.py b/setup.py index a1a8b66..2aacaa1 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ VERSION = meta['__version__'] KEYWORDS = 'machine-learning deep-learning modeling pytorch ode neural-networks differential-equations pde hacktoberfest pinn physics-informed physics-informed-neural-networks neural-operators equation-learning lightining' REQUIRED = [ - 'numpy', 'matplotlib', 'torch', 'lightning', 'pytorch_lightning' + 'numpy<2.0', 'matplotlib', 'torch', 'lightning', 'pytorch_lightning' ] EXTRAS = { diff --git a/tests/test_solvers/test_rba_pinn.py b/tests/test_solvers/test_rba_pinn.py new file mode 100644 index 0000000..6622666 --- /dev/null +++ b/tests/test_solvers/test_rba_pinn.py @@ -0,0 +1,437 @@ +import torch +import pytest + +from pina.problem import SpatialProblem, InverseProblem +from pina.operators import laplacian +from pina.geometry import CartesianDomain +from pina import Condition, LabelTensor +from pina.solvers import RBAPINN as PINN +from pina.trainer import Trainer +from pina.model import FeedForward +from pina.equation.equation import Equation +from pina.equation.equation_factory import FixedValue +from pina.loss import LpLoss + + +def laplace_equation(input_, output_): + force_term = (torch.sin(input_.extract(['x']) * torch.pi) * + torch.sin(input_.extract(['y']) * torch.pi)) + delta_u = laplacian(output_.extract(['u']), input_) + return delta_u - force_term + + +my_laplace = Equation(laplace_equation) +in_ = LabelTensor(torch.tensor([[0., 1.]]), ['x', 'y']) +out_ = LabelTensor(torch.tensor([[0.]]), ['u']) +in2_ = LabelTensor(torch.rand(60, 2), ['x', 'y']) +out2_ = LabelTensor(torch.rand(60, 1), ['u']) + + +class InversePoisson(SpatialProblem, InverseProblem): + ''' + Problem definition for the Poisson equation. + ''' + output_variables = ['u'] + x_min = -2 + x_max = 2 + y_min = -2 + y_max = 2 + data_input = LabelTensor(torch.rand(10, 2), ['x', 'y']) + data_output = LabelTensor(torch.rand(10, 1), ['u']) + spatial_domain = CartesianDomain({'x': [x_min, x_max], 'y': [y_min, y_max]}) + # define the ranges for the parameters + unknown_parameter_domain = CartesianDomain({'mu1': [-1, 1], 'mu2': [-1, 1]}) + + def laplace_equation(input_, output_, params_): + ''' + Laplace equation with a force term. + ''' + force_term = torch.exp( + - 2*(input_.extract(['x']) - params_['mu1'])**2 + - 2*(input_.extract(['y']) - params_['mu2'])**2) + delta_u = laplacian(output_, input_, components=['u'], d=['x', 'y']) + + return delta_u - force_term + + # define the conditions for the loss (boundary conditions, equation, data) + conditions = { + 'gamma1': Condition(location=CartesianDomain({'x': [x_min, x_max], + 'y': y_max}), + equation=FixedValue(0.0, components=['u'])), + 'gamma2': Condition(location=CartesianDomain( + {'x': [x_min, x_max], 'y': y_min + }), + equation=FixedValue(0.0, components=['u'])), + 'gamma3': Condition(location=CartesianDomain( + {'x': x_max, 'y': [y_min, y_max] + }), + equation=FixedValue(0.0, components=['u'])), + 'gamma4': Condition(location=CartesianDomain( + {'x': x_min, 'y': [y_min, y_max] + }), + equation=FixedValue(0.0, components=['u'])), + 'D': Condition(location=CartesianDomain( + {'x': [x_min, x_max], 'y': [y_min, y_max] + }), + equation=Equation(laplace_equation)), + 'data': Condition(input_points=data_input.extract(['x', 'y']), + output_points=data_output) + } + + +class Poisson(SpatialProblem): + output_variables = ['u'] + spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]}) + + conditions = { + 'gamma1': Condition( + location=CartesianDomain({'x': [0, 1], 'y': 1}), + equation=FixedValue(0.0)), + 'gamma2': Condition( + location=CartesianDomain({'x': [0, 1], 'y': 0}), + equation=FixedValue(0.0)), + 'gamma3': Condition( + location=CartesianDomain({'x': 1, 'y': [0, 1]}), + equation=FixedValue(0.0)), + 'gamma4': Condition( + location=CartesianDomain({'x': 0, 'y': [0, 1]}), + equation=FixedValue(0.0)), + 'D': Condition( + input_points=LabelTensor(torch.rand(size=(100, 2)), ['x', 'y']), + equation=my_laplace), + 'data': Condition( + input_points=in_, + output_points=out_), + 'data2': Condition( + input_points=in2_, + output_points=out2_) + } + + def poisson_sol(self, pts): + return -(torch.sin(pts.extract(['x']) * torch.pi) * + torch.sin(pts.extract(['y']) * torch.pi)) / (2 * torch.pi**2) + + truth_solution = poisson_sol + + +class myFeature(torch.nn.Module): + """ + Feature: sin(x) + """ + + def __init__(self): + super(myFeature, self).__init__() + + def forward(self, x): + t = (torch.sin(x.extract(['x']) * torch.pi) * + torch.sin(x.extract(['y']) * torch.pi)) + return LabelTensor(t, ['sin(x)sin(y)']) + + +# make the problem +poisson_problem = Poisson() +model = FeedForward(len(poisson_problem.input_variables), + len(poisson_problem.output_variables)) +model_extra_feats = FeedForward( + len(poisson_problem.input_variables) + 1, + len(poisson_problem.output_variables)) +extra_feats = [myFeature()] + + +def test_constructor(): + PINN(problem=poisson_problem, model=model, extra_features=None) + with pytest.raises(ValueError): + PINN(problem=poisson_problem, model=model, eta='x') + PINN(problem=poisson_problem, model=model, gamma='x') + + +def test_constructor_extra_feats(): + model_extra_feats = FeedForward( + len(poisson_problem.input_variables) + 1, + len(poisson_problem.output_variables)) + PINN(problem=poisson_problem, + model=model_extra_feats, + extra_features=extra_feats) + + +def test_train_cpu(): + poisson_problem = Poisson() + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] + n = 10 + poisson_problem.discretise_domain(n, 'grid', locations=boundaries) + pinn = PINN(problem = poisson_problem, model=model, + extra_features=None, loss=LpLoss()) + trainer = Trainer(solver=pinn, max_epochs=1, + accelerator='cpu', batch_size=20) + trainer.train() + + +def test_train_restore(): + tmpdir = "tests/tmp_restore" + poisson_problem = Poisson() + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] + n = 10 + poisson_problem.discretise_domain(n, 'grid', locations=boundaries) + pinn = PINN(problem=poisson_problem, + model=model, + extra_features=None, + loss=LpLoss()) + trainer = Trainer(solver=pinn, + max_epochs=5, + accelerator='cpu', + default_root_dir=tmpdir) + trainer.train() + ntrainer = Trainer(solver=pinn, max_epochs=15, accelerator='cpu') + t = ntrainer.train( + ckpt_path=f'{tmpdir}/lightning_logs/version_0/' + 'checkpoints/epoch=4-step=10.ckpt') + import shutil + shutil.rmtree(tmpdir) + + +def test_train_load(): + tmpdir = "tests/tmp_load" + poisson_problem = Poisson() + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] + n = 10 + poisson_problem.discretise_domain(n, 'grid', locations=boundaries) + pinn = PINN(problem=poisson_problem, + model=model, + extra_features=None, + loss=LpLoss()) + trainer = Trainer(solver=pinn, + max_epochs=15, + accelerator='cpu', + default_root_dir=tmpdir) + trainer.train() + new_pinn = PINN.load_from_checkpoint( + f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=30.ckpt', + problem = poisson_problem, model=model) + test_pts = CartesianDomain({'x': [0, 1], 'y': [0, 1]}).sample(10) + assert new_pinn.forward(test_pts).extract(['u']).shape == (10, 1) + assert new_pinn.forward(test_pts).extract( + ['u']).shape == pinn.forward(test_pts).extract(['u']).shape + torch.testing.assert_close( + new_pinn.forward(test_pts).extract(['u']), + pinn.forward(test_pts).extract(['u'])) + import shutil + shutil.rmtree(tmpdir) + +def test_train_inverse_problem_cpu(): + poisson_problem = InversePoisson() + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4', 'D'] + n = 100 + poisson_problem.discretise_domain(n, 'random', locations=boundaries) + pinn = PINN(problem = poisson_problem, model=model, + extra_features=None, loss=LpLoss()) + trainer = Trainer(solver=pinn, max_epochs=1, + accelerator='cpu', batch_size=20) + trainer.train() + + +# # TODO does not currently work +# def test_train_inverse_problem_restore(): +# tmpdir = "tests/tmp_restore_inv" +# poisson_problem = InversePoisson() +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4', 'D'] +# n = 100 +# poisson_problem.discretise_domain(n, 'random', locations=boundaries) +# pinn = PINN(problem=poisson_problem, +# model=model, +# extra_features=None, +# loss=LpLoss()) +# trainer = Trainer(solver=pinn, +# max_epochs=5, +# accelerator='cpu', +# default_root_dir=tmpdir) +# trainer.train() +# ntrainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu') +# t = ntrainer.train( +# ckpt_path=f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=4-step=10.ckpt') +# import shutil +# shutil.rmtree(tmpdir) + + +def test_train_inverse_problem_load(): + tmpdir = "tests/tmp_load_inv" + poisson_problem = InversePoisson() + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4', 'D'] + n = 100 + poisson_problem.discretise_domain(n, 'random', locations=boundaries) + pinn = PINN(problem=poisson_problem, + model=model, + extra_features=None, + loss=LpLoss()) + trainer = Trainer(solver=pinn, + max_epochs=15, + accelerator='cpu', + default_root_dir=tmpdir) + trainer.train() + new_pinn = PINN.load_from_checkpoint( + f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=30.ckpt', + problem = poisson_problem, model=model) + test_pts = CartesianDomain({'x': [0, 1], 'y': [0, 1]}).sample(10) + assert new_pinn.forward(test_pts).extract(['u']).shape == (10, 1) + assert new_pinn.forward(test_pts).extract( + ['u']).shape == pinn.forward(test_pts).extract(['u']).shape + torch.testing.assert_close( + new_pinn.forward(test_pts).extract(['u']), + pinn.forward(test_pts).extract(['u'])) + import shutil + shutil.rmtree(tmpdir) + +# # TODO fix asap. Basically sampling few variables +# # works only if both variables are in a range. +# # if one is fixed and the other not, this will +# # not work. This test also needs to be fixed and +# # insert in test problem not in test pinn. +# def test_train_cpu_sampling_few_vars(): +# poisson_problem = Poisson() +# boundaries = ['gamma1', 'gamma2', 'gamma3'] +# n = 10 +# poisson_problem.discretise_domain(n, 'grid', locations=boundaries) +# poisson_problem.discretise_domain(n, 'random', locations=['gamma4'], variables=['x']) +# poisson_problem.discretise_domain(n, 'random', locations=['gamma4'], variables=['y']) +# pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss()) +# trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'cpu'}) +# trainer.train() + + +def test_train_extra_feats_cpu(): + poisson_problem = Poisson() + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] + n = 10 + poisson_problem.discretise_domain(n, 'grid', locations=boundaries) + pinn = PINN(problem=poisson_problem, + model=model_extra_feats, + extra_features=extra_feats) + trainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu') + trainer.train() + + +# TODO, fix GitHub actions to run also on GPU +# def test_train_gpu(): +# poisson_problem = Poisson() +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 10 +# poisson_problem.discretise_domain(n, 'grid', locations=boundaries) +# pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss()) +# trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'gpu'}) +# trainer.train() + +# def test_train_gpu(): #TODO fix ASAP +# poisson_problem = Poisson() +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 10 +# poisson_problem.discretise_domain(n, 'grid', locations=boundaries) +# poisson_problem.conditions.pop('data') # The input/output pts are allocated on cpu +# pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss()) +# trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'gpu'}) +# trainer.train() + +# def test_train_2(): +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 10 +# expected_keys = [[], list(range(0, 50, 3))] +# param = [0, 3] +# for i, truth_key in zip(param, expected_keys): +# pinn = PINN(problem, model) +# pinn.discretise_domain(n, 'grid', locations=boundaries) +# pinn.discretise_domain(n, 'grid', locations=['D']) +# pinn.train(50, save_loss=i) +# assert list(pinn.history_loss.keys()) == truth_key + + +# def test_train_extra_feats(): +# pinn = PINN(problem, model_extra_feat, [myFeature()]) +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 10 +# pinn.discretise_domain(n, 'grid', locations=boundaries) +# pinn.discretise_domain(n, 'grid', locations=['D']) +# pinn.train(5) + + +# def test_train_2_extra_feats(): +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 10 +# expected_keys = [[], list(range(0, 50, 3))] +# param = [0, 3] +# for i, truth_key in zip(param, expected_keys): +# pinn = PINN(problem, model_extra_feat, [myFeature()]) +# pinn.discretise_domain(n, 'grid', locations=boundaries) +# pinn.discretise_domain(n, 'grid', locations=['D']) +# pinn.train(50, save_loss=i) +# assert list(pinn.history_loss.keys()) == truth_key + + +# def test_train_with_optimizer_kwargs(): +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 10 +# expected_keys = [[], list(range(0, 50, 3))] +# param = [0, 3] +# for i, truth_key in zip(param, expected_keys): +# pinn = PINN(problem, model, optimizer_kwargs={'lr' : 0.3}) +# pinn.discretise_domain(n, 'grid', locations=boundaries) +# pinn.discretise_domain(n, 'grid', locations=['D']) +# pinn.train(50, save_loss=i) +# assert list(pinn.history_loss.keys()) == truth_key + + +# def test_train_with_lr_scheduler(): +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 10 +# expected_keys = [[], list(range(0, 50, 3))] +# param = [0, 3] +# for i, truth_key in zip(param, expected_keys): +# pinn = PINN( +# problem, +# model, +# lr_scheduler_type=torch.optim.lr_scheduler.CyclicLR, +# lr_scheduler_kwargs={'base_lr' : 0.1, 'max_lr' : 0.3, 'cycle_momentum': False} +# ) +# pinn.discretise_domain(n, 'grid', locations=boundaries) +# pinn.discretise_domain(n, 'grid', locations=['D']) +# pinn.train(50, save_loss=i) +# assert list(pinn.history_loss.keys()) == truth_key + + +# # def test_train_batch(): +# # pinn = PINN(problem, model, batch_size=6) +# # boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# # n = 10 +# # pinn.discretise_domain(n, 'grid', locations=boundaries) +# # pinn.discretise_domain(n, 'grid', locations=['D']) +# # pinn.train(5) + + +# # def test_train_batch_2(): +# # boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# # n = 10 +# # expected_keys = [[], list(range(0, 50, 3))] +# # param = [0, 3] +# # for i, truth_key in zip(param, expected_keys): +# # pinn = PINN(problem, model, batch_size=6) +# # pinn.discretise_domain(n, 'grid', locations=boundaries) +# # pinn.discretise_domain(n, 'grid', locations=['D']) +# # pinn.train(50, save_loss=i) +# # assert list(pinn.history_loss.keys()) == truth_key + + +# if torch.cuda.is_available(): + +# # def test_gpu_train(): +# # pinn = PINN(problem, model, batch_size=20, device='cuda') +# # boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# # n = 100 +# # pinn.discretise_domain(n, 'grid', locations=boundaries) +# # pinn.discretise_domain(n, 'grid', locations=['D']) +# # pinn.train(5) + +# def test_gpu_train_nobatch(): +# pinn = PINN(problem, model, batch_size=None, device='cuda') +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 100 +# pinn.discretise_domain(n, 'grid', locations=boundaries) +# pinn.discretise_domain(n, 'grid', locations=['D']) +# pinn.train(5) +