Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 6 Sep 2024 11:09:55 +0000 (13:09 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 6 Sep 2024 11:09:55 +0000 (13:09 +0200)
main.py

diff --git a/main.py b/main.py
index 4c30771..21c609c 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -819,7 +819,7 @@ def sample_x_t_minus_1_given_x_0_x_t(x_0, x_t):
 
     x_t_minus_1 = (1 - mask_changes) * x_t + mask_changes * x_0
 
-    return result
+    return x_t_minus_1
 
 
 ######################################################################
@@ -888,7 +888,7 @@ def ae_generate(model, x_0, mask_generate, nb_iterations_max=50):
 
         hat_x_0 = (1 - mask_generate) * x_0 + mask_generate * dist.sample()
 
-        hat_x_t_minus_1 = one_iteration_prediction * x_0 + (
+        hat_x_t_minus_1 = one_iteration_prediction * hat_x_0 + (
             1 - one_iteration_prediction
         ) * sample_x_t_minus_1_given_x_0_x_t(hat_x_0, x_t)
 
@@ -913,7 +913,7 @@ def model_ae_proba_solutions(model, input, log_proba=False):
 
         for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
             mask_generate = quiz_machine.make_quiz_mask(
-                quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
+                quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
             )
             logits = logits_hat_x_0_from_random_iteration(
                 model, x_0, mask_generate, prompt_noise=args.prompt_noise
@@ -939,7 +939,7 @@ def model_ae_argmax_nb_disagreements(model, input):
         nb_disagreements = 0
         for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
             mask_generate = quiz_machine.make_quiz_mask(
-                quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
+                quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
             )
             logits = logits_hat_x_0_from_random_iteration(
                 model, x_0, mask_generate, prompt_noise=args.prompt_noise
@@ -966,7 +966,7 @@ def model_ae_argmax_predictions(model, input):
     for r, x_0 in zip(result.split(args.batch_size), input.split(args.batch_size)):
         for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
             mask_generate = quiz_machine.make_quiz_mask(
-                quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
+                quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
             )
             logits = logits_hat_x_0_from_random_iteration(
                 model, x_0, mask_generate, prompt_noise=args.prompt_noise