equation class, fix minor bugs, diff domain (#89)

* equation class
* difference domain
* dummy dataloader
* writer class
* refactoring and minor fix
This commit is contained in:
Nicola Demo
2023-05-15 16:06:01 +02:00
parent be11110bb2
commit 0e3625de80
25 changed files with 691 additions and 246 deletions

View File

@@ -42,6 +42,7 @@ class Network(torch.nn.Module):
output_variables, extra_features=None):
super().__init__()
print('HHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHH')
if extra_features is None:
extra_features = []
@@ -49,6 +50,7 @@ class Network(torch.nn.Module):
self._model = model
self._input_variables = input_variables
self._output_variables = output_variables
print(output_variables)
# check model and input/output
self._check_consistency()
@@ -59,10 +61,11 @@ class Network(torch.nn.Module):
:raises ValueError: Error in constructing the PINA network
"""
try:
tmp = torch.rand((10, len(self._input_variables)))
tmp = LabelTensor(tmp, self._input_variables)
tmp = self.forward(tmp) # trying a forward pass
tmp = LabelTensor(tmp, self._output_variables)
pass
# tmp = torch.rand((10, len(self._input_variables)))
# tmp = LabelTensor(tmp, self._input_variables)
# tmp = self.forward(tmp) # trying a forward pass
# tmp = LabelTensor(tmp, self._output_variables)
except:
raise ValueError('Error in constructing the PINA network.'
' Check compatibility of input/output'