gpu input data support (#73)
This commit is contained in:
@@ -241,7 +241,7 @@ class PINN(object):
|
||||
pts = condition.input_points.to(
|
||||
dtype=self.dtype, device=self.device)
|
||||
predicted = self.model(pts)
|
||||
residuals = predicted - condition.output_points
|
||||
residuals = predicted - condition.output_points.to(device=self.device, dtype=self.dtype) # TODO fix
|
||||
local_loss = (
|
||||
condition.data_weight*self._compute_norm(residuals))
|
||||
single_loss.append(local_loss)
|
||||
|
||||
Reference in New Issue
Block a user