Generic DeepONet (#68)

* generic deeponet
* example for generic deeponet
* adapt tests to new interface
This commit is contained in:
Francesco Andreuzzi
2023-01-11 12:07:19 +01:00
committed by GitHub
parent e227700fbc
commit 7ce080fd85
5 changed files with 280 additions and 37 deletions

View File

@@ -1,5 +1,7 @@
"""Utils module"""
from functools import reduce
import types
import torch
from torch.utils.data import DataLoader, default_collate, ConcatDataset
@@ -85,6 +87,17 @@ def torch_lhs(n, dim):
return samples
def is_function(f):
"""
Checks whether the given object `f` is a function or lambda.
:param object f: The object to be checked.
:return: `True` if `f` is a function, `False` otherwise.
:rtype: bool
"""
return type(f) == types.FunctionType or type(f) == types.LambdaType
class PinaDataset():
def __init__(self, pinn) -> None:
@@ -144,4 +157,4 @@ class PinaDataset():
return {self._location: tensor}
def __len__(self):
return self._len
return self._len