Update of LabelTensor class and fix Simplex domain (#362)
*Implement new methods in LabelTensor and fix operators
This commit is contained in:
committed by
Nicola Demo
parent
fdb8f65143
commit
7528f6ef74
@@ -18,27 +18,27 @@ def test_init_inputoutput():
|
||||
Condition(input_points=example_input_pts, output_points=example_output_pts)
|
||||
with pytest.raises(ValueError):
|
||||
Condition(example_input_pts, example_output_pts)
|
||||
with pytest.raises(TypeError):
|
||||
with pytest.raises(ValueError):
|
||||
Condition(input_points=3., output_points='example')
|
||||
with pytest.raises(TypeError):
|
||||
with pytest.raises(ValueError):
|
||||
Condition(input_points=example_domain, output_points=example_domain)
|
||||
test_init_inputoutput()
|
||||
|
||||
|
||||
def test_init_locfunc():
|
||||
Condition(location=example_domain, equation=FixedValue(0.0))
|
||||
def test_init_domainfunc():
|
||||
Condition(domain=example_domain, equation=FixedValue(0.0))
|
||||
with pytest.raises(ValueError):
|
||||
Condition(example_domain, FixedValue(0.0))
|
||||
with pytest.raises(TypeError):
|
||||
Condition(location=3., equation='example')
|
||||
with pytest.raises(TypeError):
|
||||
Condition(location=example_input_pts, equation=example_output_pts)
|
||||
with pytest.raises(ValueError):
|
||||
Condition(domain=3., equation='example')
|
||||
with pytest.raises(ValueError):
|
||||
Condition(domain=example_input_pts, equation=example_output_pts)
|
||||
|
||||
|
||||
def test_init_inputfunc():
|
||||
Condition(input_points=example_input_pts, equation=FixedValue(0.0))
|
||||
with pytest.raises(ValueError):
|
||||
Condition(example_domain, FixedValue(0.0))
|
||||
with pytest.raises(TypeError):
|
||||
with pytest.raises(ValueError):
|
||||
Condition(input_points=3., equation='example')
|
||||
with pytest.raises(TypeError):
|
||||
with pytest.raises(ValueError):
|
||||
Condition(input_points=example_domain, equation=example_output_pts)
|
||||
|
||||
@@ -40,7 +40,6 @@ def test_constructor():
|
||||
LabelTensor(torch.tensor([[-.5, .5]]), labels=["x", "y"]),
|
||||
])
|
||||
|
||||
|
||||
def test_sample():
|
||||
# sampling inside
|
||||
simplex = SimplexDomain([
|
||||
|
||||
@@ -2,7 +2,6 @@ import torch
|
||||
import pytest
|
||||
|
||||
from pina.label_tensor import LabelTensor
|
||||
#import pina
|
||||
|
||||
data = torch.rand((20, 3))
|
||||
labels_column = {
|
||||
@@ -22,8 +21,7 @@ labels_all = labels_column | labels_row
|
||||
|
||||
@pytest.mark.parametrize("labels", [labels_column, labels_row, labels_all, labels_list])
|
||||
def test_constructor(labels):
|
||||
LabelTensor(data, labels)
|
||||
|
||||
print(LabelTensor(data, labels))
|
||||
|
||||
def test_wrong_constructor():
|
||||
with pytest.raises(ValueError):
|
||||
@@ -92,7 +90,7 @@ def test_extract_3D():
|
||||
))
|
||||
assert tensor2.ndim == tensor.ndim
|
||||
assert tensor2.shape == tensor.shape
|
||||
assert tensor.labels == tensor2.labels
|
||||
assert tensor.full_labels == tensor2.full_labels
|
||||
assert new.shape != tensor.shape
|
||||
|
||||
def test_concatenation_3D():
|
||||
@@ -104,9 +102,9 @@ def test_concatenation_3D():
|
||||
lt2 = LabelTensor(data_2, labels_2)
|
||||
lt_cat = LabelTensor.cat([lt1, lt2])
|
||||
assert lt_cat.shape == (70, 3, 4)
|
||||
assert lt_cat.labels[0]['dof'] == range(70)
|
||||
assert lt_cat.labels[1]['dof'] == range(3)
|
||||
assert lt_cat.labels[2]['dof'] == ['x', 'y', 'z', 'w']
|
||||
assert lt_cat.full_labels[0]['dof'] == range(70)
|
||||
assert lt_cat.full_labels[1]['dof'] == range(3)
|
||||
assert lt_cat.full_labels[2]['dof'] == ['x', 'y', 'z', 'w']
|
||||
|
||||
data_1 = torch.rand(20, 3, 4)
|
||||
labels_1 = ['x', 'y', 'z', 'w']
|
||||
@@ -116,9 +114,9 @@ def test_concatenation_3D():
|
||||
lt2 = LabelTensor(data_2, labels_2)
|
||||
lt_cat = LabelTensor.cat([lt1, lt2], dim=1)
|
||||
assert lt_cat.shape == (20, 5, 4)
|
||||
assert lt_cat.labels[0]['dof'] == range(20)
|
||||
assert lt_cat.labels[1]['dof'] == range(5)
|
||||
assert lt_cat.labels[2]['dof'] == ['x', 'y', 'z', 'w']
|
||||
assert lt_cat.full_labels[0]['dof'] == range(20)
|
||||
assert lt_cat.full_labels[1]['dof'] == range(5)
|
||||
assert lt_cat.full_labels[2]['dof'] == ['x', 'y', 'z', 'w']
|
||||
|
||||
data_1 = torch.rand(20, 3, 2)
|
||||
labels_1 = ['x', 'y']
|
||||
@@ -128,9 +126,9 @@ def test_concatenation_3D():
|
||||
lt2 = LabelTensor(data_2, labels_2)
|
||||
lt_cat = LabelTensor.cat([lt1, lt2], dim=2)
|
||||
assert lt_cat.shape == (20, 3, 5)
|
||||
assert lt_cat.labels[2]['dof'] == ['x', 'y', 'z', 'w', 'a']
|
||||
assert lt_cat.labels[0]['dof'] == range(20)
|
||||
assert lt_cat.labels[1]['dof'] == range(3)
|
||||
assert lt_cat.full_labels[2]['dof'] == ['x', 'y', 'z', 'w', 'a']
|
||||
assert lt_cat.full_labels[0]['dof'] == range(20)
|
||||
assert lt_cat.full_labels[1]['dof'] == range(3)
|
||||
|
||||
data_1 = torch.rand(20, 2, 4)
|
||||
labels_1 = ['x', 'y', 'z', 'w']
|
||||
@@ -140,7 +138,6 @@ def test_concatenation_3D():
|
||||
lt2 = LabelTensor(data_2, labels_2)
|
||||
with pytest.raises(ValueError):
|
||||
LabelTensor.cat([lt1, lt2], dim=2)
|
||||
|
||||
data_1 = torch.rand(20, 3, 2)
|
||||
labels_1 = ['x', 'y']
|
||||
lt1 = LabelTensor(data_1, labels_1)
|
||||
@@ -149,9 +146,9 @@ def test_concatenation_3D():
|
||||
lt2 = LabelTensor(data_2, labels_2)
|
||||
lt_cat = LabelTensor.cat([lt1, lt2], dim=2)
|
||||
assert lt_cat.shape == (20, 3, 5)
|
||||
assert lt_cat.labels[2]['dof'] == range(5)
|
||||
assert lt_cat.labels[0]['dof'] == range(20)
|
||||
assert lt_cat.labels[1]['dof'] == range(3)
|
||||
assert lt_cat.full_labels[2]['dof'] == range(5)
|
||||
assert lt_cat.full_labels[0]['dof'] == range(20)
|
||||
assert lt_cat.full_labels[1]['dof'] == range(3)
|
||||
|
||||
|
||||
def test_summation():
|
||||
@@ -165,7 +162,7 @@ def test_summation():
|
||||
assert lt_sum.ndim == lt_sum.ndim
|
||||
assert lt_sum.shape[0] == 20
|
||||
assert lt_sum.shape[1] == 3
|
||||
assert lt_sum.labels == labels_all
|
||||
assert lt_sum.full_labels == labels_all
|
||||
assert torch.eq(lt_sum.tensor, torch.ones(20,3)*2).all()
|
||||
lt1 = LabelTensor(torch.ones(20,3), labels_all)
|
||||
lt2 = LabelTensor(torch.ones(20,3), labels_all)
|
||||
@@ -174,29 +171,92 @@ def test_summation():
|
||||
assert lt_sum.ndim == lt_sum.ndim
|
||||
assert lt_sum.shape[0] == 20
|
||||
assert lt_sum.shape[1] == 3
|
||||
assert lt_sum.labels == labels_all
|
||||
assert lt_sum.full_labels == labels_all
|
||||
assert torch.eq(lt_sum.tensor, torch.ones(20,3)*2).all()
|
||||
|
||||
def test_append_3D():
|
||||
data_1 = torch.rand(20, 3, 4)
|
||||
labels_1 = ['x', 'y', 'z', 'w']
|
||||
lt1 = LabelTensor(data_1, labels_1)
|
||||
data_2 = torch.rand(50, 3, 4)
|
||||
labels_2 = ['x', 'y', 'z', 'w']
|
||||
lt2 = LabelTensor(data_2, labels_2)
|
||||
lt1 = lt1.append(lt2)
|
||||
assert lt1.shape == (70, 3, 4)
|
||||
assert lt1.labels[0]['dof'] == range(70)
|
||||
assert lt1.labels[1]['dof'] == range(3)
|
||||
assert lt1.labels[2]['dof'] == ['x', 'y', 'z', 'w']
|
||||
data_1 = torch.rand(20, 3, 2)
|
||||
labels_1 = ['x', 'y']
|
||||
lt1 = LabelTensor(data_1, labels_1)
|
||||
data_2 = torch.rand(20, 3, 2)
|
||||
labels_2 = ['z', 'w']
|
||||
lt2 = LabelTensor(data_2, labels_2)
|
||||
lt1 = lt1.append(lt2, mode='cross')
|
||||
lt1 = lt1.append(lt2)
|
||||
assert lt1.shape == (20, 3, 4)
|
||||
assert lt1.labels[0]['dof'] == range(20)
|
||||
assert lt1.labels[1]['dof'] == range(3)
|
||||
assert lt1.labels[2]['dof'] == ['x', 'y', 'z', 'w']
|
||||
assert lt1.full_labels[0]['dof'] == range(20)
|
||||
assert lt1.full_labels[1]['dof'] == range(3)
|
||||
assert lt1.full_labels[2]['dof'] == ['x', 'y', 'z', 'w']
|
||||
|
||||
def test_append_2D():
|
||||
data_1 = torch.rand(20, 2)
|
||||
labels_1 = ['x', 'y']
|
||||
lt1 = LabelTensor(data_1, labels_1)
|
||||
data_2 = torch.rand(20, 2)
|
||||
labels_2 = ['z', 'w']
|
||||
lt2 = LabelTensor(data_2, labels_2)
|
||||
lt1 = lt1.append(lt2, mode='cross')
|
||||
assert lt1.shape == (400, 4)
|
||||
assert lt1.full_labels[0]['dof'] == range(400)
|
||||
assert lt1.full_labels[1]['dof'] == ['x', 'y', 'z', 'w']
|
||||
|
||||
def test_vstack_3D():
|
||||
data_1 = torch.rand(20, 3, 2)
|
||||
labels_1 = {1:{'dof': ['a', 'b', 'c'], 'name': 'first'}, 2: {'dof': ['x', 'y'], 'name': 'second'}}
|
||||
lt1 = LabelTensor(data_1, labels_1)
|
||||
data_2 = torch.rand(20, 3, 2)
|
||||
labels_1 = {1:{'dof': ['a', 'b', 'c'], 'name': 'first'}, 2: {'dof': ['x', 'y'], 'name': 'second'}}
|
||||
lt2 = LabelTensor(data_2, labels_1)
|
||||
lt_stacked = LabelTensor.vstack([lt1, lt2])
|
||||
assert lt_stacked.shape == (40, 3, 2)
|
||||
assert lt_stacked.full_labels[0]['dof'] == range(40)
|
||||
assert lt_stacked.full_labels[1]['dof'] == ['a', 'b', 'c']
|
||||
assert lt_stacked.full_labels[2]['dof'] == ['x', 'y']
|
||||
assert lt_stacked.full_labels[1]['name'] == 'first'
|
||||
assert lt_stacked.full_labels[2]['name'] == 'second'
|
||||
|
||||
def test_vstack_2D():
|
||||
data_1 = torch.rand(20, 2)
|
||||
labels_1 = { 1: {'dof': ['x', 'y'], 'name': 'second'}}
|
||||
lt1 = LabelTensor(data_1, labels_1)
|
||||
data_2 = torch.rand(20, 2)
|
||||
labels_1 = { 1: {'dof': ['x', 'y'], 'name': 'second'}}
|
||||
lt2 = LabelTensor(data_2, labels_1)
|
||||
lt_stacked = LabelTensor.vstack([lt1, lt2])
|
||||
assert lt_stacked.shape == (40, 2)
|
||||
assert lt_stacked.full_labels[0]['dof'] == range(40)
|
||||
assert lt_stacked.full_labels[1]['dof'] == ['x', 'y']
|
||||
assert lt_stacked.full_labels[0]['name'] == 0
|
||||
assert lt_stacked.full_labels[1]['name'] == 'second'
|
||||
|
||||
def test_sorting():
|
||||
data = torch.ones(20, 5)
|
||||
data[:,0] = data[:,0]*4
|
||||
data[:,1] = data[:,1]*2
|
||||
data[:,2] = data[:,2]
|
||||
data[:,3] = data[:,3]*5
|
||||
data[:,4] = data[:,4]*3
|
||||
labels = ['d', 'b', 'a', 'e', 'c']
|
||||
lt_data = LabelTensor(data, labels)
|
||||
lt_sorted = LabelTensor.sort_labels(lt_data)
|
||||
assert lt_sorted.shape == (20,5)
|
||||
assert lt_sorted.labels == ['a', 'b', 'c', 'd', 'e']
|
||||
assert torch.eq(lt_sorted.tensor[:,0], torch.ones(20) * 1).all()
|
||||
assert torch.eq(lt_sorted.tensor[:,1], torch.ones(20) * 2).all()
|
||||
assert torch.eq(lt_sorted.tensor[:,2], torch.ones(20) * 3).all()
|
||||
assert torch.eq(lt_sorted.tensor[:,3], torch.ones(20) * 4).all()
|
||||
assert torch.eq(lt_sorted.tensor[:,4], torch.ones(20) * 5).all()
|
||||
|
||||
data = torch.ones(20, 4, 5)
|
||||
data[:,0,:] = data[:,0]*4
|
||||
data[:,1,:] = data[:,1]*2
|
||||
data[:,2,:] = data[:,2]
|
||||
data[:,3,:] = data[:,3]*3
|
||||
labels = {1: {'dof': ['d', 'b', 'a', 'c'], 'name': 1}}
|
||||
lt_data = LabelTensor(data, labels)
|
||||
lt_sorted = LabelTensor.sort_labels(lt_data, dim=1)
|
||||
assert lt_sorted.shape == (20,4, 5)
|
||||
assert lt_sorted.full_labels[1]['dof'] == ['a', 'b', 'c', 'd']
|
||||
assert torch.eq(lt_sorted.tensor[:,0,:], torch.ones(20,5) * 1).all()
|
||||
assert torch.eq(lt_sorted.tensor[:,1,:], torch.ones(20,5) * 2).all()
|
||||
assert torch.eq(lt_sorted.tensor[:,2,:], torch.ones(20,5) * 3).all()
|
||||
assert torch.eq(lt_sorted.tensor[:,3,:], torch.ones(20,5) * 4).all()
|
||||
117
tests/test_label_tensor/test_label_tensor_01.py
Normal file
117
tests/test_label_tensor/test_label_tensor_01.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import torch
|
||||
import pytest
|
||||
|
||||
from pina import LabelTensor
|
||||
|
||||
data = torch.rand((20, 3))
|
||||
labels = ['a', 'b', 'c']
|
||||
|
||||
|
||||
def test_constructor():
|
||||
LabelTensor(data, labels)
|
||||
|
||||
|
||||
def test_wrong_constructor():
|
||||
with pytest.raises(ValueError):
|
||||
LabelTensor(data, ['a', 'b'])
|
||||
|
||||
|
||||
def test_labels():
|
||||
tensor = LabelTensor(data, labels)
|
||||
assert isinstance(tensor, torch.Tensor)
|
||||
assert tensor.labels == labels
|
||||
with pytest.raises(ValueError):
|
||||
tensor.labels = labels[:-1]
|
||||
|
||||
|
||||
def test_extract():
|
||||
label_to_extract = ['a', 'c']
|
||||
tensor = LabelTensor(data, labels)
|
||||
new = tensor.extract(label_to_extract)
|
||||
assert new.labels == label_to_extract
|
||||
assert new.shape[1] == len(label_to_extract)
|
||||
assert torch.all(torch.isclose(data[:, 0::2], new))
|
||||
|
||||
|
||||
def test_extract_onelabel():
|
||||
label_to_extract = ['a']
|
||||
tensor = LabelTensor(data, labels)
|
||||
new = tensor.extract(label_to_extract)
|
||||
assert new.ndim == 2
|
||||
assert new.labels == label_to_extract
|
||||
assert new.shape[1] == len(label_to_extract)
|
||||
assert torch.all(torch.isclose(data[:, 0].reshape(-1, 1), new))
|
||||
|
||||
|
||||
def test_wrong_extract():
|
||||
label_to_extract = ['a', 'cc']
|
||||
tensor = LabelTensor(data, labels)
|
||||
with pytest.raises(ValueError):
|
||||
tensor.extract(label_to_extract)
|
||||
|
||||
|
||||
def test_extract_order():
|
||||
label_to_extract = ['c', 'a']
|
||||
tensor = LabelTensor(data, labels)
|
||||
new = tensor.extract(label_to_extract)
|
||||
expected = torch.cat(
|
||||
(data[:, 2].reshape(-1, 1), data[:, 0].reshape(-1, 1)),
|
||||
dim=1)
|
||||
assert new.labels == label_to_extract
|
||||
assert new.shape[1] == len(label_to_extract)
|
||||
assert torch.all(torch.isclose(expected, new))
|
||||
|
||||
|
||||
def test_merge():
|
||||
tensor = LabelTensor(data, labels)
|
||||
tensor_a = tensor.extract('a')
|
||||
tensor_b = tensor.extract('b')
|
||||
tensor_c = tensor.extract('c')
|
||||
|
||||
tensor_bc = tensor_b.append(tensor_c)
|
||||
assert torch.allclose(tensor_bc, tensor.extract(['b', 'c']))
|
||||
|
||||
|
||||
def test_merge2():
|
||||
tensor = LabelTensor(data, labels)
|
||||
tensor_b = tensor.extract('b')
|
||||
tensor_c = tensor.extract('c')
|
||||
|
||||
tensor_bc = tensor_b.append(tensor_c)
|
||||
assert torch.allclose(tensor_bc, tensor.extract(['b', 'c']))
|
||||
|
||||
|
||||
def test_getitem():
|
||||
tensor = LabelTensor(data, labels)
|
||||
tensor_view = tensor['a']
|
||||
assert tensor_view.labels == ['a']
|
||||
assert torch.allclose(tensor_view.flatten(), data[:, 0])
|
||||
|
||||
tensor_view = tensor['a', 'c']
|
||||
assert tensor_view.labels == ['a', 'c']
|
||||
assert torch.allclose(tensor_view, data[:, 0::2])
|
||||
|
||||
def test_getitem2():
|
||||
tensor = LabelTensor(data, labels)
|
||||
tensor_view = tensor[:5]
|
||||
assert tensor_view.labels == labels
|
||||
assert torch.allclose(tensor_view, data[:5])
|
||||
|
||||
idx = torch.randperm(tensor.shape[0])
|
||||
tensor_view = tensor[idx]
|
||||
assert tensor_view.labels == labels
|
||||
|
||||
def test_slice():
|
||||
tensor = LabelTensor(data, labels)
|
||||
tensor_view = tensor[:5, :2]
|
||||
assert tensor_view.labels == labels[:2]
|
||||
assert torch.allclose(tensor_view, data[:5, :2])
|
||||
|
||||
tensor_view2 = tensor[3]
|
||||
|
||||
assert tensor_view2.labels == labels
|
||||
assert torch.allclose(tensor_view2, data[3])
|
||||
|
||||
tensor_view3 = tensor[:, 2]
|
||||
assert tensor_view3.labels == labels[2]
|
||||
assert torch.allclose(tensor_view3, data[:, 2].reshape(-1, 1))
|
||||
@@ -27,15 +27,15 @@ def test_grad_scalar_output():
|
||||
grad_tensor_s = grad(tensor_s, inp)
|
||||
true_val = 2*inp
|
||||
assert grad_tensor_s.shape == inp.shape
|
||||
assert grad_tensor_s.labels[grad_tensor_s.ndim-1]['dof'] == [
|
||||
f'd{tensor_s.labels[tensor_s.ndim-1]["dof"][0]}d{i}' for i in inp.labels[inp.ndim-1]['dof']
|
||||
assert grad_tensor_s.labels == [
|
||||
f'd{tensor_s.labels[0]}d{i}' for i in inp.labels
|
||||
]
|
||||
assert torch.allclose(grad_tensor_s, true_val)
|
||||
|
||||
grad_tensor_s = grad(tensor_s, inp, d=['x', 'y'])
|
||||
assert grad_tensor_s.shape == (20, 2)
|
||||
assert grad_tensor_s.labels[grad_tensor_s.ndim-1]['dof'] == [
|
||||
f'd{tensor_s.labels[tensor_s.ndim-1]["dof"][0]}d{i}' for i in ['x', 'y']
|
||||
assert grad_tensor_s.labels == [
|
||||
f'd{tensor_s.labels[0]}d{i}' for i in ['x', 'y']
|
||||
]
|
||||
assert torch.allclose(grad_tensor_s, true_val)
|
||||
|
||||
|
||||
@@ -27,50 +27,46 @@ class Poisson(SpatialProblem):
|
||||
|
||||
conditions = {
|
||||
'gamma1':
|
||||
Condition(domain=CartesianDomain({
|
||||
'x': [0, 1],
|
||||
'y': 1
|
||||
}),
|
||||
equation=FixedValue(0.0)),
|
||||
Condition(domain=CartesianDomain({
|
||||
'x': [0, 1],
|
||||
'y': 1
|
||||
}),
|
||||
equation=FixedValue(0.0)),
|
||||
'gamma2':
|
||||
Condition(domain=CartesianDomain({
|
||||
'x': [0, 1],
|
||||
'y': 0
|
||||
}),
|
||||
equation=FixedValue(0.0)),
|
||||
Condition(domain=CartesianDomain({
|
||||
'x': [0, 1],
|
||||
'y': 0
|
||||
}),
|
||||
equation=FixedValue(0.0)),
|
||||
'gamma3':
|
||||
Condition(domain=CartesianDomain({
|
||||
'x': 1,
|
||||
'y': [0, 1]
|
||||
}),
|
||||
equation=FixedValue(0.0)),
|
||||
Condition(domain=CartesianDomain({
|
||||
'x': 1,
|
||||
'y': [0, 1]
|
||||
}),
|
||||
equation=FixedValue(0.0)),
|
||||
'gamma4':
|
||||
Condition(domain=CartesianDomain({
|
||||
'x': 0,
|
||||
'y': [0, 1]
|
||||
}),
|
||||
equation=FixedValue(0.0)),
|
||||
Condition(domain=CartesianDomain({
|
||||
'x': 0,
|
||||
'y': [0, 1]
|
||||
}),
|
||||
equation=FixedValue(0.0)),
|
||||
'D':
|
||||
Condition(domain=CartesianDomain({
|
||||
'x': [0, 1],
|
||||
'y': [0, 1]
|
||||
}),
|
||||
equation=my_laplace),
|
||||
Condition(domain=CartesianDomain({
|
||||
'x': [0, 1],
|
||||
'y': [0, 1]
|
||||
}),
|
||||
equation=my_laplace),
|
||||
'data':
|
||||
Condition(input_points=in_, output_points=out_)
|
||||
Condition(input_points=in_, output_points=out_)
|
||||
}
|
||||
|
||||
def poisson_sol(self, pts):
|
||||
return -(torch.sin(pts.extract(['x']) * torch.pi) *
|
||||
torch.sin(pts.extract(['y']) * torch.pi)) / (2 * torch.pi**2)
|
||||
torch.sin(pts.extract(['y']) * torch.pi)) / (2 * torch.pi ** 2)
|
||||
|
||||
truth_solution = poisson_sol
|
||||
|
||||
|
||||
# make the problem
|
||||
poisson_problem = Poisson()
|
||||
print(poisson_problem.input_pts)
|
||||
|
||||
def test_discretise_domain():
|
||||
n = 10
|
||||
poisson_problem = Poisson()
|
||||
@@ -83,7 +79,7 @@ def test_discretise_domain():
|
||||
assert poisson_problem.input_pts[b].shape[0] == n
|
||||
|
||||
poisson_problem.discretise_domain(n, 'grid', locations=['D'])
|
||||
assert poisson_problem.input_pts['D'].shape[0] == n**2
|
||||
assert poisson_problem.input_pts['D'].shape[0] == n ** 2
|
||||
poisson_problem.discretise_domain(n, 'random', locations=['D'])
|
||||
assert poisson_problem.input_pts['D'].shape[0] == n
|
||||
|
||||
@@ -94,14 +90,15 @@ def test_discretise_domain():
|
||||
assert poisson_problem.input_pts['D'].shape[0] == n
|
||||
|
||||
|
||||
# def test_sampling_few_variables():
|
||||
# n = 10
|
||||
# poisson_problem.discretise_domain(n,
|
||||
# 'grid',
|
||||
# locations=['D'],
|
||||
# variables=['x'])
|
||||
# assert poisson_problem.input_pts['D'].shape[1] == 1
|
||||
# assert poisson_problem._have_sampled_points['D'] is False
|
||||
def test_sampling_few_variables():
|
||||
n = 10
|
||||
poisson_problem = Poisson()
|
||||
poisson_problem.discretise_domain(n,
|
||||
'grid',
|
||||
locations=['D'],
|
||||
variables=['x'])
|
||||
assert poisson_problem.input_pts['D'].shape[1] == 1
|
||||
assert poisson_problem._have_sampled_points['D'] is False
|
||||
|
||||
|
||||
def test_variables_correct_order_sampling():
|
||||
@@ -117,13 +114,11 @@ def test_variables_correct_order_sampling():
|
||||
variables=['y'])
|
||||
assert poisson_problem.input_pts['D'].labels == sorted(
|
||||
poisson_problem.input_variables)
|
||||
|
||||
poisson_problem.discretise_domain(n,
|
||||
'grid',
|
||||
locations=['D'])
|
||||
assert poisson_problem.input_pts['D'].labels == sorted(
|
||||
poisson_problem.input_variables)
|
||||
|
||||
poisson_problem.discretise_domain(n,
|
||||
'grid',
|
||||
locations=['D'],
|
||||
@@ -140,8 +135,8 @@ def test_add_points():
|
||||
poisson_problem.discretise_domain(0,
|
||||
'random',
|
||||
locations=['D'],
|
||||
variables=['x','y'])
|
||||
new_pts = LabelTensor(torch.tensor([[0.5,-0.5]]),labels=['x','y'])
|
||||
variables=['x', 'y'])
|
||||
new_pts = LabelTensor(torch.tensor([[0.5, -0.5]]), labels=['x', 'y'])
|
||||
poisson_problem.add_points({'D': new_pts})
|
||||
assert torch.isclose(poisson_problem.input_pts['D'].extract('x'),new_pts.extract('x'))
|
||||
assert torch.isclose(poisson_problem.input_pts['D'].extract('y'),new_pts.extract('y'))
|
||||
assert torch.isclose(poisson_problem.input_pts['D'].extract('x'), new_pts.extract('x'))
|
||||
assert torch.isclose(poisson_problem.input_pts['D'].extract('y'), new_pts.extract('y'))
|
||||
|
||||
Reference in New Issue
Block a user