"""Tests for the GAROM solver on a small parametric Gaussian data problem."""

import torch
from pina.problem import AbstractProblem
from pina import Condition, LabelTensor
from pina.solvers import GAROM
from pina.trainer import Trainer
import torch.nn as nn
import matplotlib.tri as tri


# Sizes shared by every solver built in the tests below.
_INPUT_DIMENSION = 900
_PARAMETERS_DIMENSION = 2
_NOISE_DIMENSION = 12
_HIDDEN_DIMENSION = 64


def func(x, mu1, mu2):
    """Evaluate ``exp(-((x0 - mu1)^2 + (x1 - mu2)^2))`` for every row of ``x``.

    :param x: tensor indexed as ``x[:, 0]`` and ``x[:, 1]`` (two coordinate
        columns are read; any further columns are ignored).
    :param mu1: value subtracted from the first coordinate column.
    :param mu2: value subtracted from the second coordinate column.
    :return: tensor with one value per row of ``x``.
    """
    x_m1 = (x[:, 0] - mu1).pow(2)
    x_m2 = (x[:, 1] - mu2).pow(2)
    return torch.exp(-(x_m1 + x_m2))


class ParametricGaussian(AbstractProblem):
    """Data-only problem: ``func`` sampled on a 30x30 grid for 20x20 params.

    Everything is built at class-definition time: ``params`` holds the 400
    ``(mu1, mu2)`` pairs, ``snapshots`` holds the matching 900-value fields,
    and the single ``'data'`` condition pairs the two.
    """

    output_variables = [f'u_{i}' for i in range(900)]

    # params
    xx = torch.linspace(-1, 1, 20)
    yy = xx
    params = LabelTensor(torch.cartesian_prod(xx, yy), labels=['mu1', 'mu2'])

    # define domain
    x = torch.linspace(-1, 1, 30)
    domain = torch.cartesian_prod(x, x)
    triang = tri.Triangulation(domain[:, 0], domain[:, 1])

    # NOTE: keep this an explicit loop. Inside a class body a comprehension
    # gets its own scope and cannot see the class-level name ``domain``.
    sol = []
    for p in params:
        sol.append(func(domain, p[0], p[1]))
    snapshots = LabelTensor(torch.stack(sol), labels=output_variables)

    # define conditions
    conditions = {
        'data': Condition(input_points=params, output_points=snapshots)
    }


# simple Generator Network
class Generator(nn.Module):
    """Fully connected generator conditioned on the problem parameters.

    Each forward pass draws fresh uniform noise, concatenates it with an
    embedding of the parameters and maps the result to ``input_dimension``
    outputs.

    :param input_dimension: size of the generated output vector.
    :param parameters_dimension: number of conditioning parameters.
    :param noise_dimension: size of the sampled noise vector.
    :param activation: activation class instantiated between linear layers.
    """

    def __init__(self,
                 input_dimension,
                 parameters_dimension,
                 noise_dimension,
                 activation=torch.nn.SiLU):
        super().__init__()

        self._noise_dimension = noise_dimension
        self._activation = activation

        # Input width is 6 * noise_dimension: noise_dimension noise values
        # plus the 5 * noise_dimension embedding produced by ``condition``.
        self.model = torch.nn.Sequential(
            torch.nn.Linear(6 * self._noise_dimension, input_dimension // 6),
            self._activation(),
            torch.nn.Linear(input_dimension // 6, input_dimension // 3),
            self._activation(),
            torch.nn.Linear(input_dimension // 3, input_dimension))

        self.condition = torch.nn.Sequential(
            torch.nn.Linear(parameters_dimension, 2 * self._noise_dimension),
            self._activation(),
            torch.nn.Linear(2 * self._noise_dimension,
                            5 * self._noise_dimension))

    def forward(self, param):
        """Generate one output row per row of ``param``.

        :param param: conditioning parameters; ``param.shape[0]`` sets the
            number of noise samples, and its device and dtype are reused.
        :return: output of ``self.model`` on the concatenated noise and
            parameter embedding.
        """
        # uniform sampling in [-1, 1]
        z = torch.rand(size=(param.shape[0], self._noise_dimension),
                       device=param.device,
                       dtype=param.dtype,
                       requires_grad=True)
        z = 2. * z - 1.

        # conditioning by concatenation of mapped parameters
        input_ = torch.cat((z, self.condition(param)), dim=-1)
        out = self.model(input_)
        return out


# Simple Discriminator Network
class Discriminator(nn.Module):
    """Encoder/decoder discriminator conditioned on the problem parameters.

    The input is encoded to ``hidden_dimension`` values, concatenated with a
    ``hidden_dimension`` embedding of the parameters and decoded back to
    ``input_dimension`` values.

    :param input_dimension: size of the vectors being encoded and decoded.
    :param parameter_dimension: number of conditioning parameters.
    :param hidden_dimension: size of the latent code and of the embedding.
    :param activation: activation class instantiated between linear layers.
    """

    def __init__(self,
                 input_dimension,
                 parameter_dimension,
                 hidden_dimension,
                 activation=torch.nn.ReLU):
        super().__init__()

        self._activation = activation

        self.encoding = torch.nn.Sequential(
            torch.nn.Linear(input_dimension, input_dimension // 3),
            self._activation(),
            torch.nn.Linear(input_dimension // 3, input_dimension // 6),
            self._activation(),
            torch.nn.Linear(input_dimension // 6, hidden_dimension))

        # Input width is 2 * hidden_dimension: latent code plus embedding.
        self.decoding = torch.nn.Sequential(
            torch.nn.Linear(2 * hidden_dimension, input_dimension // 6),
            self._activation(),
            torch.nn.Linear(input_dimension // 6, input_dimension // 3),
            self._activation(),
            torch.nn.Linear(input_dimension // 3, input_dimension),
        )

        self.condition = torch.nn.Sequential(
            torch.nn.Linear(parameter_dimension, hidden_dimension // 2),
            self._activation(),
            torch.nn.Linear(hidden_dimension // 2, hidden_dimension))

    def forward(self, data):
        """Encode, condition and decode a ``(x, condition)`` pair.

        :param data: two-item iterable unpacked as ``x, condition``.
        :return: output of ``self.decoding`` on the concatenated latent code
            and parameter embedding.
        """
        x, condition = data
        encoding = self.encoding(x)
        conditioning = torch.cat((encoding, self.condition(condition)),
                                 dim=-1)
        decoding = self.decoding(conditioning)
        return decoding


problem = ParametricGaussian()


def _make_solver():
    """Build a GAROM solver on ``problem`` with fresh, untrained networks."""
    return GAROM(problem=problem,
                 generator=Generator(
                     input_dimension=_INPUT_DIMENSION,
                     parameters_dimension=_PARAMETERS_DIMENSION,
                     noise_dimension=_NOISE_DIMENSION),
                 discriminator=Discriminator(
                     input_dimension=_INPUT_DIMENSION,
                     parameter_dimension=_PARAMETERS_DIMENSION,
                     hidden_dimension=_HIDDEN_DIMENSION))


def test_constructor():
    """The solver can be constructed from the problem and both networks."""
    _make_solver()


def test_train_cpu():
    """Four epochs of CPU training with batch size 20 run without raising."""
    solver = _make_solver()
    trainer = Trainer(solver=solver,
                      max_epochs=4,
                      accelerator='cpu',
                      batch_size=20)
    trainer.train()


def test_sample():
    """``sample`` returns a tensor shaped like the reference snapshots."""
    solver = _make_solver()
    solver.sample(problem.params)
    assert solver.sample(problem.params).shape == problem.snapshots.shape


def test_forward():
    """Calling the solver works with and without the optional arguments."""
    solver = _make_solver()
    # Smoke call exercising the ``mc_steps`` and ``variance`` keywords.
    solver(problem.params, mc_steps=100, variance=True)
    assert solver(problem.params).shape == problem.snapshots.shape