body = c_quizzes.repeat(n, 1)
             if n < c_quiz_multiplier:
                 tail = c_quizzes[
-                    torch.randperm(c_quizzes.size(0))[: nb_samples // 2 - body.size(0)]
+                    torch.randperm(c_quizzes.size(0), device=c_quizzes.device)[
+                        : nb_samples // 2 - body.size(0)
+                    ]
                 ]
                 c_quizzes = torch.cat([body, tail], dim=0)
             else:
                 c_quizzes = body
 
         if c_quizzes.size(0) > nb_samples // 2:
-            i = torch.randperm(c_quizzes.size(0))[: nb_samples // 2]
+            i = torch.randperm(c_quizzes.size(0), device=c_quizzes.device)[
+                : nb_samples // 2
+            ]
             c_quizzes = c_quizzes[i]
 
         w_quizzes = problem.generate_w_quizzes(nb_samples - c_quizzes.size(0))