equation class, fix minor bugs, diff domain (#89)

* equation class
* difference domain
* dummy dataloader
* writer class
* refactoring and minor fix
This commit is contained in:
Nicola Demo
2023-05-15 16:06:01 +02:00
parent be11110bb2
commit 0e3625de80
25 changed files with 691 additions and 246 deletions

View File

@@ -0,0 +1,24 @@
""" Module """
import torch
from .equation import Equation
class SystemEquation(Equation):
def __init__(self, list_equation):
if not isinstance(list_equation, list):
raise TypeError('list_equation must be a list of functions')
self.equations = []
for i, equation in enumerate(list_equation):
if not callable(equation):
raise TypeError('list_equation must be a list of functions')
self.equations.append(Equation(equation))
def residual(self, input_, output_):
return torch.mean(
torch.stack([
equation.residual(input_, output_)
for equation in self.equations
]),
dim=0)