Network handles forward for all solvers
This commit is contained in:
committed by
Nicola Demo
parent
4844640727
commit
c90301c204
@@ -1,11 +1,12 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ..utils import check_consistency
|
||||
from ..label_tensor import LabelTensor
|
||||
|
||||
|
||||
class Network(torch.nn.Module):
|
||||
|
||||
def __init__(self, model, extra_features=None):
|
||||
def __init__(self, model, input_variables, output_variables, extra_features=None):
|
||||
"""
|
||||
Network class with standard forward method
|
||||
and possibility to pass extra features. This
|
||||
@@ -14,6 +15,10 @@ class Network(torch.nn.Module):
|
||||
|
||||
:param model: The torch model to convert in a PINA model.
|
||||
:type model: torch.nn.Module
|
||||
:param list(str) input_variables: The input variables of the :class:`AbstractProblem`, whose type depends on the
|
||||
type of domain (spatial, temporal, and parameter).
|
||||
:param list(str) output_variables: The output variables of the :class:`AbstractProblem`, whose type depends on the
|
||||
problem setting.
|
||||
:param extra_features: List of torch models to augment the input, defaults to None.
|
||||
:type extra_features: list(torch.nn.Module)
|
||||
"""
|
||||
@@ -21,7 +26,12 @@ class Network(torch.nn.Module):
|
||||
|
||||
# check model consistency
|
||||
check_consistency(model, nn.Module)
|
||||
check_consistency(input_variables, str)
|
||||
check_consistency(output_variables, str)
|
||||
|
||||
self._model = model
|
||||
self._input_variables = input_variables
|
||||
self._output_variables = output_variables
|
||||
|
||||
# check consistency and assign extra fatures
|
||||
if extra_features is None:
|
||||
@@ -46,14 +56,55 @@ class Network(torch.nn.Module):
|
||||
:param torch.Tensor x: Input of the network.
|
||||
:return torch.Tensor: Output of the network.
|
||||
"""
|
||||
# extract torch.Tensor from corresponding label
|
||||
# in case `input_variables = []` all points are used
|
||||
if self._input_variables:
|
||||
x = x.extract(self._input_variables)
|
||||
|
||||
# extract features and append
|
||||
for feature in self._extra_features:
|
||||
x = x.append(feature(x))
|
||||
# perform forward pass
|
||||
return self._model(x)
|
||||
|
||||
# convert LabelTensor to torch.Tensor
|
||||
x = x.as_subclass(torch.Tensor)
|
||||
|
||||
# perform forward pass (using torch.Tensor) + converting to LabelTensor
|
||||
output = self._model(x).as_subclass(LabelTensor)
|
||||
|
||||
# set the labels for LabelTensor
|
||||
output.labels = self._output_variables
|
||||
|
||||
return output
|
||||
|
||||
def forward_map(self, x):
|
||||
"""
|
||||
Forward method for Network class when the input is
|
||||
a tuple. This class implements the standard forward method,
|
||||
and it adds the possibility to pass extra features.
|
||||
All the PINA models ``forward`` s are overriden
|
||||
by this class, to enable :class:`pina.label_tensor.LabelTensor` labels
|
||||
extraction.
|
||||
|
||||
:param list (torch.Tensor) | tuple(torch.Tensor) x: Input of the network.
|
||||
:return torch.Tensor: Output of the network.
|
||||
|
||||
.. note::
|
||||
This function does not extract the input variables, all the variables
|
||||
are used for both tensors. Output variables are correctly applied.
|
||||
"""
|
||||
# convert LabelTensor s to torch.Tensor s
|
||||
x = list(map(lambda x: x.as_subclass(torch.Tensor), x))
|
||||
|
||||
# perform forward pass (using torch.Tensor) + converting to LabelTensor
|
||||
output = self._model(x).as_subclass(LabelTensor)
|
||||
|
||||
# set the labels for LabelTensor
|
||||
output.labels = self._output_variables
|
||||
|
||||
return output
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
def torchmodel(self):
|
||||
return self._model
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user