equation class, fix minor bugs, diff domain (#89)
* equation class * difference domain * dummy dataloader * writer class * refactoring and minor fix
This commit is contained in:
24
pina/equation/system_equation.py
Normal file
24
pina/equation/system_equation.py
Normal file
@@ -0,0 +1,24 @@
|
||||
""" Module """
|
||||
import torch
|
||||
from .equation import Equation
|
||||
|
||||
class SystemEquation(Equation):
|
||||
|
||||
def __init__(self, list_equation):
|
||||
if not isinstance(list_equation, list):
|
||||
raise TypeError('list_equation must be a list of functions')
|
||||
|
||||
self.equations = []
|
||||
for i, equation in enumerate(list_equation):
|
||||
if not callable(equation):
|
||||
raise TypeError('list_equation must be a list of functions')
|
||||
|
||||
self.equations.append(Equation(equation))
|
||||
|
||||
def residual(self, input_, output_):
|
||||
return torch.mean(
|
||||
torch.stack([
|
||||
equation.residual(input_, output_)
|
||||
for equation in self.equations
|
||||
]),
|
||||
dim=0)
|
||||
Reference in New Issue
Block a user