equation class, fix minor bugs, diff domain (#89)

* equation class
* difference domain
* dummy dataloader
* writer class
* refactoring and minor fix
This commit is contained in:
Nicola Demo
2023-05-15 16:06:01 +02:00
parent be11110bb2
commit 0e3625de80
25 changed files with 691 additions and 246 deletions

View File

10
pina/equation/equation.py Normal file
View File

@@ -0,0 +1,10 @@
""" Module """
from .equation_interface import EquationInterface
class Equation(EquationInterface):
def __init__(self, equation):
self.__equation = equation
def residual(self, input_, output_):
return self.__equation(input_, output_)

View File

@@ -0,0 +1,37 @@
""" Module """
from .equation import Equation
from ..operators import grad, div, nabla
class FixedValue(Equation):
def __init__(self, value, components=None):
def equation(input_, output_):
if components is None:
return output_ - value
return output_.extract(components) - value
super().__init__(equation)
class FixedGradient(Equation):
def __init__(self, value, components=None, d=None):
def equation(input_, output_):
return grad(output_, input_, components=components, d=d) - value
super().__init__(equation)
class FixedFlux(Equation):
def __init__(self, value, components=None, d=None):
def equation(input_, output_):
return div(output_, input_, components=components, d=d) - value
super().__init__(equation)
class Laplace(Equation):
def __init__(self, components=None, d=None):
def equation(input_, output_):
return nabla(output_, input_, components=components, d=d)
super().__init__(equation)

View File

@@ -0,0 +1,13 @@
""" Module for EquationInterface class """
from abc import ABCMeta, abstractmethod
class EquationInterface(metaclass=ABCMeta):
"""
The abstract `AbstractProblem` class. All the class defining a PINA Problem
should be inheritied from this class.
In the definition of a PINA problem, the fundamental elements are:
the output variables, the condition(s), and the domain(s) where the
conditions are applied.
"""

View File

@@ -0,0 +1,24 @@
""" Module """
import torch
from .equation import Equation
class SystemEquation(Equation):
def __init__(self, list_equation):
if not isinstance(list_equation, list):
raise TypeError('list_equation must be a list of functions')
self.equations = []
for i, equation in enumerate(list_equation):
if not callable(equation):
raise TypeError('list_equation must be a list of functions')
self.equations.append(Equation(equation))
def residual(self, input_, output_):
return torch.mean(
torch.stack([
equation.residual(input_, output_)
for equation in self.equations
]),
dim=0)