22 lines
602 B
Python
22 lines
602 B
Python
import torch
|
|
from pina.model.block import PODBlock
|
|
from pina.model import FeedForward
|
|
|
|
class PODNN(torch.nn.Module):
|
|
def __init__(self, pod_rank, layers, func=torch.nn.Softplus):
|
|
super().__init__()
|
|
self.pod = PODBlock(pod_rank, scale_coefficients=False)
|
|
self.nn = FeedForward(
|
|
input_dimensions=3,
|
|
output_dimensions=pod_rank,
|
|
layers=layers,
|
|
func=func,
|
|
)
|
|
|
|
|
|
def forward(self, p):
|
|
coefficients = self.nn(p)
|
|
return self.pod.expand(coefficients)
|
|
|
|
def fit_pod(self, x):
|
|
self.pod.fit(x) |