equation class, fix minor bugs, diff domain (#89)

* equation class
* difference domain
* dummy dataloader
* writer class
* refactoring and minor fix
This commit is contained in:
Nicola Demo
2023-05-15 16:06:01 +02:00
parent be11110bb2
commit 0e3625de80
25 changed files with 691 additions and 246 deletions

View File

@@ -5,12 +5,10 @@ from pina import LabelTensor, Condition, CartesianDomain, PINN
from pina.problem import SpatialProblem
from pina.model import FeedForward
from pina.operators import nabla
from pina.equation.equation_factory import FixedValue
example_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]})
def example_dirichlet(input_, output_):
value = 0.0
return output_.extract(['u']) - value
example_input_pts = LabelTensor(torch.tensor([[0, 0, 0]]), ['x', 'y', 'z'])
example_output_pts = LabelTensor(torch.tensor([[1, 2]]), ['a', 'b'])
@@ -21,22 +19,22 @@ def test_init_inputoutput():
with pytest.raises(TypeError):
Condition(input_points=3., output_points='example')
with pytest.raises(TypeError):
Condition(input_points=example_domain, output_points=example_dirichlet)
Condition(input_points=example_domain, output_points=example_domain)
def test_init_locfunc():
Condition(location=example_domain, function=example_dirichlet)
Condition(location=example_domain, equation=FixedValue(0.0))
with pytest.raises(ValueError):
Condition(example_domain, example_dirichlet)
Condition(example_domain, FixedValue(0.0))
with pytest.raises(TypeError):
Condition(location=3., function='example')
Condition(location=3., equation='example')
with pytest.raises(TypeError):
Condition(location=example_input_pts, function=example_output_pts)
Condition(location=example_input_pts, equation=example_output_pts)
def test_init_inputfunc():
Condition(input_points=example_input_pts, function=example_dirichlet)
Condition(input_points=example_input_pts, equation=FixedValue(0.0))
with pytest.raises(ValueError):
Condition(example_domain, example_dirichlet)
Condition(example_domain, FixedValue(0.0))
with pytest.raises(TypeError):
Condition(input_points=3., function='example')
Condition(input_points=3., equation='example')
with pytest.raises(TypeError):
Condition(input_points=example_domain, function=example_output_pts)
Condition(input_points=example_domain, equation=example_output_pts)