From 6c1e5af6fdb00d97136df7ac6ae89bd51fb6ccf7 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 18 Sep 2024 16:50:13 +0200 Subject: [PATCH] Update. --- main.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/main.py b/main.py index 464b217..a357687 100755 --- a/main.py +++ b/main.py @@ -483,6 +483,10 @@ def prioritized_rand(low): def ae_generate(model, nb, local_device=main_device): model.eval().to(local_device) + # We loop through the iterations first and through the + # mini-batches second so that we keep only the samples that have + # not stabilized + all_input = quiz_machine.pure_noise(nb, local_device) all_masks = all_input.new_full(all_input.size(), 1) all_changed = torch.full((all_input.size(0),), True, device=all_input.device) -- 2.39.5