Update.
author: François Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 14:31:50 +0000 (16:31 +0200)
committer: François Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 14:31:50 +0000 (16:31 +0200)
main.py

diff --git a/main.py b/main.py
index c6eedfb..464b217 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -480,41 +480,42 @@ def prioritized_rand(low):
     return y
 
 
-def ae_generate(model, nb, local_device=main_device, desc="generate"):
+def ae_generate(model, nb, local_device=main_device):
     model.eval().to(local_device)
 
     all_input = quiz_machine.pure_noise(nb, local_device)
     all_masks = all_input.new_full(all_input.size(), 1)
+    all_changed = torch.full((all_input.size(0),), True, device=all_input.device)
 
-    src = zip(
-        all_input.split(args.physical_batch_size),
-        all_masks.split(args.physical_batch_size),
-    )
+    for it in range(args.diffusion_nb_iterations):
+        if not all_changed.any():
+            break
 
-    if desc is not None:
-        src = tqdm.tqdm(
-            src,
-            dynamic_ncols=True,
-            desc="generate",
-            total=all_input.size(0) // args.physical_batch_size,
+        sub_input = all_input[all_changed].clone()
+        sub_masks = all_masks[all_changed].clone()
+        sub_changed = all_changed[all_changed].clone()
+
+        src = zip(
+            sub_input.split(args.physical_batch_size),
+            sub_masks.split(args.physical_batch_size),
+            sub_changed.split(args.physical_batch_size),
         )
 
-    for input, masks in src:
-        changed = True
-        for it in range(args.diffusion_nb_iterations):
+        for input, masks, changed in src:
             with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                 logits = model(input * 2 + masks)
             dist = torch.distributions.categorical.Categorical(logits=logits)
             output = dist.sample()
-
             r = prioritized_rand(input != output)
             mask_changes = (r <= args.diffusion_proba_corruption).long() * masks
             update = (1 - mask_changes) * input + mask_changes * output
-            if update.equal(input):
-                break
-            else:
-                changed = changed & (update != input).max(dim=1).values
-                input[changed] = update[changed]
+            changed[...] = changed & (update != input).max(dim=1).values
+            input[...] = update
+
+        a = all_changed.clone()
+        all_input[a] = sub_input
+        all_masks[a] = sub_masks
+        all_changed[a] = sub_changed
 
     return all_input
 
@@ -709,10 +710,7 @@ def generate_c_quizzes(models, nb_to_generate, local_device=main_device):
         generator_id = model.id
 
         c_quizzes = ae_generate(
-            model=model,
-            nb=args.physical_batch_size,
-            local_device=local_device,
-            desc=None,
+            model=model, nb=args.physical_batch_size * 10, local_device=local_device
         )
 
         # Select the ones that are solved properly by some models and
@@ -847,11 +845,7 @@ if args.quizzes is not None:
             mask_generate = quiz_machine.make_quiz_mask(
                 quizzes=quizzes, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
             )
-            result = ae_generate(
-                model,
-                (1 - mask_generate) * quizzes,
-                mask_generate,
-            )
+            result = ae_generate(model, (1 - mask_generate) * quizzes, mask_generate)
             record.append(result)
 
     result = torch.cat(record, dim=0)