Fix SupervisedSolver GPU bug and implement GraphSolver (#346)

* Fix some bugs
* Solve bug with GPU and model_summary parameters in SupervisedSolver class
* Implement GraphSolver class
* Fix Tutorial 5
This commit is contained in:
FilippoOlivo
2024-09-21 18:55:57 +02:00
committed by Nicola Demo
parent 30f865d912
commit 2be57944ba
10 changed files with 334 additions and 164 deletions

View File

@@ -425,7 +425,7 @@ class LabelTensor(torch.Tensor):
raise NotImplementedError
labels = [tensor.labels for tensor in tensors]
print(labels)
def requires_grad_(self, mode=True):
lt = super().requires_grad_(mode)
@@ -436,7 +436,6 @@ class LabelTensor(torch.Tensor):
def dtype(self):
return super().dtype
def to(self, *args, **kwargs):
"""
Performs Tensor dtype and/or device conversion. For more details, see
@@ -447,7 +446,6 @@ class LabelTensor(torch.Tensor):
new.data = tmp.data
return new
def clone(self, *args, **kwargs):
"""
Clone the LabelTensor. For more details, see