Generic DeepONet (#68)
* generic deeponet * example for generic deeponet * adapt tests to new interface
This commit is contained in:
committed by
GitHub
parent
e227700fbc
commit
7ce080fd85
@@ -1,5 +1,7 @@
|
||||
"""Utils module"""
|
||||
from functools import reduce
|
||||
import types
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, default_collate, ConcatDataset
|
||||
|
||||
@@ -85,6 +87,17 @@ def torch_lhs(n, dim):
|
||||
return samples
|
||||
|
||||
|
||||
def is_function(f):
|
||||
"""
|
||||
Checks whether the given object `f` is a function or lambda.
|
||||
|
||||
:param object f: The object to be checked.
|
||||
:return: `True` if `f` is a function, `False` otherwise.
|
||||
:rtype: bool
|
||||
"""
|
||||
return type(f) == types.FunctionType or type(f) == types.LambdaType
|
||||
|
||||
|
||||
class PinaDataset():
|
||||
|
||||
def __init__(self, pinn) -> None:
|
||||
@@ -144,4 +157,4 @@ class PinaDataset():
|
||||
return {self._location: tensor}
|
||||
|
||||
def __len__(self):
|
||||
return self._len
|
||||
return self._len
|
||||
|
||||
Reference in New Issue
Block a user