equation class, fix minor bugs, diff domain (#89)
* equation class * difference domain * dummy dataloader * writer class * refactoring and minor fix
This commit is contained in:
@@ -42,6 +42,7 @@ class Network(torch.nn.Module):
|
||||
output_variables, extra_features=None):
|
||||
super().__init__()
|
||||
|
||||
print('HHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHH')
|
||||
if extra_features is None:
|
||||
extra_features = []
|
||||
|
||||
@@ -49,6 +50,7 @@ class Network(torch.nn.Module):
|
||||
self._model = model
|
||||
self._input_variables = input_variables
|
||||
self._output_variables = output_variables
|
||||
print(output_variables)
|
||||
|
||||
# check model and input/output
|
||||
self._check_consistency()
|
||||
@@ -59,10 +61,11 @@ class Network(torch.nn.Module):
|
||||
:raises ValueError: Error in constructing the PINA network
|
||||
"""
|
||||
try:
|
||||
tmp = torch.rand((10, len(self._input_variables)))
|
||||
tmp = LabelTensor(tmp, self._input_variables)
|
||||
tmp = self.forward(tmp) # trying a forward pass
|
||||
tmp = LabelTensor(tmp, self._output_variables)
|
||||
pass
|
||||
# tmp = torch.rand((10, len(self._input_variables)))
|
||||
# tmp = LabelTensor(tmp, self._input_variables)
|
||||
# tmp = self.forward(tmp) # trying a forward pass
|
||||
# tmp = LabelTensor(tmp, self._output_variables)
|
||||
except:
|
||||
raise ValueError('Error in constructing the PINA network.'
|
||||
' Check compatibility of input/output'
|
||||
|
||||
Reference in New Issue
Block a user