Implementation of DataLoader and DataModule (#383)
Refactoring for 0.2 * Data module, data loader and dataset * Refactor LabelTensor * Refactor solvers Co-authored-by: dario-coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
dd43c8304c
commit
a27bd35443
@@ -119,6 +119,7 @@ class LowRankBlock(torch.nn.Module):
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
# extract basis
|
||||
coords = coords.as_subclass(torch.Tensor)
|
||||
basis = self._basis(coords)
|
||||
# reshape [B, N, D, 2*rank]
|
||||
shape = list(basis.shape[:-1]) + [-1, 2 * self.rank]
|
||||
|
||||
@@ -29,7 +29,8 @@ class Network(torch.nn.Module):
|
||||
# check model consistency
|
||||
check_consistency(model, nn.Module)
|
||||
check_consistency(input_variables, str)
|
||||
check_consistency(output_variables, str)
|
||||
if output_variables is not None:
|
||||
check_consistency(output_variables, str)
|
||||
|
||||
self._model = model
|
||||
self._input_variables = input_variables
|
||||
@@ -67,16 +68,15 @@ class Network(torch.nn.Module):
|
||||
# in case `input_variables = []` all points are used
|
||||
if self._input_variables:
|
||||
x = x.extract(self._input_variables)
|
||||
|
||||
# extract features and append
|
||||
for feature in self._extra_features:
|
||||
x = x.append(feature(x))
|
||||
|
||||
# perform forward pass + converting to LabelTensor
|
||||
output = self._model(x).as_subclass(LabelTensor)
|
||||
|
||||
# set the labels for LabelTensor
|
||||
output.labels = self._output_variables
|
||||
x = x.as_subclass(torch.Tensor)
|
||||
output = self._model(x)
|
||||
if self._output_variables is not None:
|
||||
output = LabelTensor(output, self._output_variables)
|
||||
|
||||
return output
|
||||
|
||||
@@ -97,15 +97,9 @@ class Network(torch.nn.Module):
|
||||
This function does not extract the input variables, all the variables
|
||||
are used for both tensors. Output variables are correctly applied.
|
||||
"""
|
||||
# convert LabelTensor s to torch.Tensor s
|
||||
x = list(map(lambda x: x.as_subclass(torch.Tensor), x))
|
||||
|
||||
# perform forward pass (using torch.Tensor) + converting to LabelTensor
|
||||
output = self._model(x).as_subclass(LabelTensor)
|
||||
|
||||
# set the labels for LabelTensor
|
||||
output.labels = self._output_variables
|
||||
|
||||
output = LabelTensor(self._model(x.tensor), self._output_variables)
|
||||
return output
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user