@@ -246,7 +246,7 @@ class FourierFeatureEmbedding(torch.nn.Module):
|
|||||||
:rtype: torch.Tensor
|
:rtype: torch.Tensor
|
||||||
"""
|
"""
|
||||||
# compute random matrix multiplication
|
# compute random matrix multiplication
|
||||||
out = torch.mm(x, self._matrix)
|
out = torch.mm(x, self._matrix.to(device=x.device, dtype=x.dtype))
|
||||||
# return embedding
|
# return embedding
|
||||||
return torch.cat(
|
return torch.cat(
|
||||||
[torch.cos(2 * torch.pi * out), torch.sin(2 * torch.pi * out)],
|
[torch.cos(2 * torch.pi * out), torch.sin(2 * torch.pi * out)],
|
||||||
|
|||||||
Reference in New Issue
Block a user