Implementation of DataLoader and DataModule (#383)

Refactoring for 0.2
* Data module, data loader and dataset
* Refactor LabelTensor
* Refactor solvers

Co-authored-by: dario-coscia <dariocos99@gmail.com>
This commit is contained in:
Filippo Olivo
2024-11-27 16:01:39 +01:00
committed by Nicola Demo
parent dd43c8304c
commit a27bd35443
34 changed files with 827 additions and 1349 deletions

View File

@@ -119,6 +119,7 @@ class LowRankBlock(torch.nn.Module):
:rtype: torch.Tensor
"""
# extract basis
coords = coords.as_subclass(torch.Tensor)
basis = self._basis(coords)
# reshape [B, N, D, 2*rank]
shape = list(basis.shape[:-1]) + [-1, 2 * self.rank]

View File

@@ -29,7 +29,8 @@ class Network(torch.nn.Module):
# check model consistency
check_consistency(model, nn.Module)
check_consistency(input_variables, str)
check_consistency(output_variables, str)
if output_variables is not None:
check_consistency(output_variables, str)
self._model = model
self._input_variables = input_variables
@@ -67,16 +68,15 @@ class Network(torch.nn.Module):
# in case `input_variables = []` all points are used
if self._input_variables:
x = x.extract(self._input_variables)
# extract features and append
for feature in self._extra_features:
x = x.append(feature(x))
# perform forward pass + converting to LabelTensor
output = self._model(x).as_subclass(LabelTensor)
# set the labels for LabelTensor
output.labels = self._output_variables
x = x.as_subclass(torch.Tensor)
output = self._model(x)
if self._output_variables is not None:
output = LabelTensor(output, self._output_variables)
return output
@@ -97,15 +97,9 @@ class Network(torch.nn.Module):
This function does not extract the input variables, all the variables
are used for both tensors. Output variables are correctly applied.
"""
# convert LabelTensor s to torch.Tensor s
x = list(map(lambda x: x.as_subclass(torch.Tensor), x))
# perform forward pass (using torch.Tensor) + converting to LabelTensor
output = self._model(x).as_subclass(LabelTensor)
# set the labels for LabelTensor
output.labels = self._output_variables
output = LabelTensor(self._model(x.tensor), self._output_variables)
return output
@property