diff --git a/pina/model/layers/embedding.py b/pina/model/layers/embedding.py index c9f41da..4248136 100644 --- a/pina/model/layers/embedding.py +++ b/pina/model/layers/embedding.py @@ -246,7 +246,7 @@ class FourierFeatureEmbedding(torch.nn.Module): :rtype: torch.Tensor """ # compute random matrix multiplication - out = torch.mm(x, self._matrix) + out = torch.mm(x, self._matrix.to(device=x.device, dtype=x.dtype)) # return embedding return torch.cat( [torch.cos(2 * torch.pi * out), torch.sin(2 * torch.pi * out)],