* equation class
* difference domain
* dummy dataloader
* writer class
* refactoring and minor fix

24 lines · 745 B · Python
""" Module """
|
|
import torch
|
|
from .equation import Equation
|
|
|
|
class SystemEquation(Equation):
    """Combine several equations into a single residual.

    Each callable in ``list_equation`` is wrapped in an ``Equation``.
    :meth:`residual` evaluates every wrapped equation on the same
    ``(input_, output_)`` pair, stacks the results along a new leading
    dimension (dim 0) and reduces over that dimension.
    """

    def __init__(self, list_equation, reduction='mean'):
        """
        :param list list_equation: Non-empty list of callables, one per
            equation of the system. Each one is wrapped in ``Equation``.
        :param reduction: How :meth:`residual` combines the per-equation
            residuals: ``'mean'`` (default, the original behaviour),
            ``'sum'``, ``'none'`` (return the stacked residuals without
            reducing), or a callable invoked as ``reduction(stacked, dim=0)``.
        :raises TypeError: If ``list_equation`` is not a list, or one of its
            elements is not callable.
        :raises ValueError: If ``list_equation`` is empty, or ``reduction``
            is not one of the supported options.
        """
        # NOTE: Equation.__init__ is not called here; the wrapped
        # per-equation objects are kept in ``self.equations`` instead.
        if not isinstance(list_equation, list):
            raise TypeError('list_equation must be a list of functions')
        if not list_equation:
            # torch.stack rejects an empty sequence, so an empty system could
            # never produce a residual; report it at construction time.
            raise ValueError(
                'list_equation must contain at least one function')
        if not all(callable(equation) for equation in list_equation):
            raise TypeError('list_equation must be a list of functions')

        self.equations = [Equation(equation) for equation in list_equation]

        # Resolve the reduction once; ``None`` means "return unreduced".
        # Strings are checked with isinstance first so that comparing an
        # arbitrary object (e.g. a tensor) with '==' is never attempted.
        if callable(reduction):
            self._reduction = reduction
        elif isinstance(reduction, str) and reduction in ('mean', 'sum', 'none'):
            self._reduction = {
                'mean': torch.mean,
                'sum': torch.sum,
                'none': None,
            }[reduction]
        else:
            raise ValueError(
                "reduction must be 'mean', 'sum', 'none' or a callable, "
                f"got {reduction!r}")

    def residual(self, input_, output_):
        """Evaluate every equation of the system and reduce the results.

        :param input_: Forwarded unchanged to each wrapped equation's
            ``residual``.
        :param output_: Forwarded unchanged to each wrapped equation's
            ``residual``.
        :return: The per-equation residuals stacked along a new dimension 0
            and reduced over it (mean by default). With ``reduction='none'``
            the stacked tensor itself is returned.
        """
        # torch.stack requires every per-equation residual to have the same
        # shape; the result gains a leading axis of size len(self.equations).
        stacked = torch.stack([
            equation.residual(input_, output_)
            for equation in self.equations
        ])
        if self._reduction is None:
            return stacked
        return self._reduction(stacked, dim=0)