gpu fourierembedding (#313)

Adding gpu to fourier embedding
This commit is contained in:
Dario Coscia
2024-06-27 17:36:54 +02:00
committed by GitHub
parent ba79984be2
commit f9316e359a

View File

@@ -246,7 +246,7 @@ class FourierFeatureEmbedding(torch.nn.Module):
:rtype: torch.Tensor
"""
# compute random matrix multiplication
out = torch.mm(x, self._matrix)
out = torch.mm(x, self._matrix.to(device=x.device, dtype=x.dtype))
# return embedding
return torch.cat(
[torch.cos(2 * torch.pi * out), torch.sin(2 * torch.pi * out)],