Fix SupervisedSolver GPU bug and implement GraphSolver (#346)
* Fix some bugs * Solve bug with GPU and model_summary parameters in SupervisedSolver class * Implement GraphSolver class * Fix Tutorial 5
This commit is contained in:
committed by
Nicola Demo
parent
30f865d912
commit
2be57944ba
@@ -4,6 +4,7 @@ from .sample_dataset import SamplePointDataset
|
||||
from .data_dataset import DataPointDataset
|
||||
from .pina_batch import Batch
|
||||
|
||||
|
||||
class SamplePointLoader:
|
||||
"""
|
||||
This class is used to create a dataloader to use during the training.
|
||||
@@ -95,7 +96,7 @@ class SamplePointLoader:
|
||||
self.batch_output_pts = torch.tensor_split(
|
||||
dataset.output_pts, batch_num
|
||||
)
|
||||
print(input_labels)
|
||||
#print(input_labels)
|
||||
for i in range(len(self.batch_input_pts)):
|
||||
self.batch_input_pts[i].labels = input_labels
|
||||
self.batch_output_pts[i].labels = output_labels
|
||||
@@ -161,7 +162,6 @@ class SamplePointLoader:
|
||||
self.batch_input_pts,
|
||||
self.batch_output_pts,
|
||||
self.batch_data_conditions)
|
||||
print(batch.input.labels)
|
||||
|
||||
self.batches.append(batch)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user