lh solved (#55)
This commit is contained in:
10
pina/span.py
10
pina/span.py
@@ -3,6 +3,7 @@ import torch
|
|||||||
|
|
||||||
from .location import Location
|
from .location import Location
|
||||||
from .label_tensor import LabelTensor
|
from .label_tensor import LabelTensor
|
||||||
|
from .utils import torch_lhs
|
||||||
|
|
||||||
|
|
||||||
class Span(Location):
|
class Span(Location):
|
||||||
@@ -41,10 +42,7 @@ class Span(Location):
|
|||||||
elif mode == 'grid':
|
elif mode == 'grid':
|
||||||
pts = torch.linspace(0, 1, n).reshape(-1, 1)
|
pts = torch.linspace(0, 1, n).reshape(-1, 1)
|
||||||
elif mode == 'lh' or mode == 'latin':
|
elif mode == 'lh' or mode == 'latin':
|
||||||
from scipy.stats import qmc
|
pts = torch_lhs(n, dim)
|
||||||
sampler = qmc.LatinHypercube(d=dim)
|
|
||||||
pts = sampler.random(n)
|
|
||||||
pts = torch.from_numpy(pts)
|
|
||||||
|
|
||||||
pts *= bounds[:, 1] - bounds[:, 0]
|
pts *= bounds[:, 1] - bounds[:, 0]
|
||||||
pts += bounds[:, 0]
|
pts += bounds[:, 0]
|
||||||
@@ -83,7 +81,7 @@ class Span(Location):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def _Nd_sampler(n, mode, variables):
|
def _Nd_sampler(n, mode, variables):
|
||||||
""" Sample ll the variables together """
|
""" Sample all the variables together """
|
||||||
pairs = [(k, v) for k, v in self.range_.items() if k in variables]
|
pairs = [(k, v) for k, v in self.range_.items() if k in variables]
|
||||||
keys, values = map(list, zip(*pairs))
|
keys, values = map(list, zip(*pairs))
|
||||||
bounds = torch.tensor(values)
|
bounds = torch.tensor(values)
|
||||||
@@ -107,7 +105,7 @@ class Span(Location):
|
|||||||
|
|
||||||
if mode in ['grid', 'chebyshev']:
|
if mode in ['grid', 'chebyshev']:
|
||||||
return _1d_sampler(n, mode, variables)
|
return _1d_sampler(n, mode, variables)
|
||||||
elif mode in ['random', 'lhs']:
|
elif mode in ['random', 'lh', 'latin']:
|
||||||
return _Nd_sampler(n, mode, variables)
|
return _Nd_sampler(n, mode, variables)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'mode={mode} is not valid.')
|
raise ValueError(f'mode={mode} is not valid.')
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ from torch.utils.data import DataLoader, default_collate, ConcatDataset
|
|||||||
|
|
||||||
from .label_tensor import LabelTensor
|
from .label_tensor import LabelTensor
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def number_parameters(model, aggregate=True, only_trainable=True): # TODO: check
|
def number_parameters(model, aggregate=True, only_trainable=True): # TODO: check
|
||||||
"""
|
"""
|
||||||
@@ -49,6 +51,40 @@ def merge_two_tensors(tensor1, tensor2):
|
|||||||
return tensor1.append(tensor2)
|
return tensor1.append(tensor2)
|
||||||
|
|
||||||
|
|
||||||
|
def torch_lhs(n, dim):
|
||||||
|
"""Latin Hypercube Sampling torch routine.
|
||||||
|
Sampling in range $[0, 1)^d$.
|
||||||
|
|
||||||
|
:param int n: number of samples
|
||||||
|
:param int dim: dimensions of latin hypercube
|
||||||
|
:return: samples
|
||||||
|
:rtype: torch.tensor
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not isinstance(n, int):
|
||||||
|
raise TypeError('number of point n must be int')
|
||||||
|
|
||||||
|
if not isinstance(dim, int):
|
||||||
|
raise TypeError('dim must be int')
|
||||||
|
|
||||||
|
if dim < 1:
|
||||||
|
raise ValueError('dim must be greater than one')
|
||||||
|
|
||||||
|
samples = torch.rand(size=(n, dim))
|
||||||
|
|
||||||
|
perms = torch.tile(torch.arange(1, n + 1), (dim, 1))
|
||||||
|
|
||||||
|
for row in range(dim):
|
||||||
|
idx_perm = torch.randperm(perms.shape[-1])
|
||||||
|
perms[row, :] = perms[row, idx_perm]
|
||||||
|
|
||||||
|
perms = perms.T
|
||||||
|
|
||||||
|
samples = (perms - samples) / n
|
||||||
|
|
||||||
|
return samples
|
||||||
|
|
||||||
|
|
||||||
class PinaDataset():
|
class PinaDataset():
|
||||||
|
|
||||||
def __init__(self, pinn) -> None:
|
def __init__(self, pinn) -> None:
|
||||||
@@ -108,4 +144,4 @@ class PinaDataset():
|
|||||||
return {self._location: tensor}
|
return {self._location: tensor}
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self._len
|
return self._len
|
||||||
@@ -63,6 +63,12 @@ def test_span_pts():
|
|||||||
pinn.span_pts(n, 'random', locations=['D'])
|
pinn.span_pts(n, 'random', locations=['D'])
|
||||||
assert pinn.input_pts['D'].shape[0] == n
|
assert pinn.input_pts['D'].shape[0] == n
|
||||||
|
|
||||||
|
pinn.span_pts(n, 'latin', locations=['D'])
|
||||||
|
assert pinn.input_pts['D'].shape[0] == n
|
||||||
|
|
||||||
|
pinn.span_pts(n, 'lh', locations=['D'])
|
||||||
|
assert pinn.input_pts['D'].shape[0] == n
|
||||||
|
|
||||||
|
|
||||||
def test_train():
|
def test_train():
|
||||||
pinn = PINN(problem, model)
|
pinn = PINN(problem, model)
|
||||||
@@ -124,4 +130,4 @@ if torch.cuda.is_available():
|
|||||||
n = 100
|
n = 100
|
||||||
pinn.span_pts(n, 'grid', boundaries)
|
pinn.span_pts(n, 'grid', boundaries)
|
||||||
pinn.span_pts(n, 'grid', locations=['D'])
|
pinn.span_pts(n, 'grid', locations=['D'])
|
||||||
pinn.train(5)
|
pinn.train(5)
|
||||||
Reference in New Issue
Block a user