Files
thermal-conduction-ml/model/pod_nn.py
2025-09-19 12:07:31 +02:00

22 lines
602 B
Python

import torch
from pina.model.block import PODBlock
from pina.model import FeedForward
class PODNN(torch.nn.Module):
    """POD-NN model: a feed-forward network predicts POD coefficients.

    The network maps an input parameter tensor to ``pod_rank`` coefficients,
    which are then passed to ``PODBlock.expand`` to produce the model output.
    """

    def __init__(
        self, pod_rank, layers, func=torch.nn.Softplus, *, input_dimensions=3
    ):
        """Build the POD block and the coefficient network.

        :param int pod_rank: Rank given to ``PODBlock``; also the output size
            of the feed-forward network, so each output is one coefficient.
        :param layers: Hidden-layer specification forwarded unchanged to
            ``FeedForward(layers=...)``.
        :param func: Activation forwarded to ``FeedForward(func=...)``.
            Defaults to ``torch.nn.Softplus``.
        :param int input_dimensions: Input size of the feed-forward network,
            i.e. the number of parameters per sample fed to :meth:`forward`.
            Keyword-only; defaults to 3, the value previously hard-coded.
        """
        super().__init__()
        # scale_coefficients=False: the network's raw outputs are what
        # gets handed to expand() in forward().
        self.pod = PODBlock(pod_rank, scale_coefficients=False)
        self.nn = FeedForward(
            input_dimensions=input_dimensions,
            output_dimensions=pod_rank,
            layers=layers,
            func=func,
        )

    def forward(self, p):
        """Predict POD coefficients from ``p`` and expand them.

        :param p: Input parameter tensor; presumably shaped
            (n_samples, input_dimensions) -- TODO confirm against callers.
        :return: The result of ``self.pod.expand`` on the predicted
            coefficients.
        """
        coefficients = self.nn(p)
        return self.pod.expand(coefficients)

    def fit_pod(self, x):
        """Fit the POD block on snapshot data ``x``.

        NOTE(review): ``forward`` calls ``self.pod.expand``, which presumably
        needs a fitted basis, so this should run before the first forward
        pass -- confirm against PINA's ``PODBlock``.

        :param x: Snapshot data passed unchanged to ``PODBlock.fit``.
        """
        self.pod.fit(x)