gpu input data support (#73)

This commit is contained in:
Nicola Demo
2023-03-10 13:57:31 +01:00
committed by GitHub
parent c5b2596910
commit 465718aead
2 changed files with 37 additions and 34 deletions

View File

@@ -241,7 +241,7 @@ class PINN(object):
pts = condition.input_points.to(
dtype=self.dtype, device=self.device)
predicted = self.model(pts)
residuals = predicted - condition.output_points
residuals = predicted - condition.output_points.to(device=self.device, dtype=self.dtype) # TODO fix
local_loss = (
condition.data_weight*self._compute_norm(residuals))
single_loss.append(local_loss)