@@ -246,7 +246,7 @@ class FourierFeatureEmbedding(torch.nn.Module):
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
# compute random matrix multiplication
|
||||
out = torch.mm(x, self._matrix)
|
||||
out = torch.mm(x, self._matrix.to(device=x.device, dtype=x.dtype))
|
||||
# return embedding
|
||||
return torch.cat(
|
||||
[torch.cos(2 * torch.pi * out), torch.sin(2 * torch.pi * out)],
|
||||
|
||||
Reference in New Issue
Block a user