Add Graph support in Dataset and Dataloader
This commit is contained in:
committed by
Nicola Demo
parent
eb146ea2ea
commit
ccc5f5a322
@@ -130,14 +130,13 @@ class SupervisedSolver(SolverInterface):
|
||||
if not hasattr(condition, "output_points"):
|
||||
raise NotImplementedError(
|
||||
f"{type(self).__name__} works only in data-driven mode.")
|
||||
|
||||
output_pts = out[condition_idx == condition_id]
|
||||
input_pts = pts[condition_idx == condition_id]
|
||||
|
||||
input_pts.labels = pts.labels
|
||||
output_pts.labels = out.labels
|
||||
|
||||
loss = (self.loss_data(input_pts=input_pts, output_pts=output_pts))
|
||||
loss = self.loss_data(input_pts=input_pts, output_pts=output_pts)
|
||||
loss = loss.as_subclass(torch.Tensor)
|
||||
|
||||
self.log("mean_loss", float(loss), prog_bar=True, logger=True)
|
||||
|
||||
Reference in New Issue
Block a user