Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 3 Sep 2024 06:51:14 +0000 (08:51 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 3 Sep 2024 06:51:14 +0000 (08:51 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index 4860073..9b2282f 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -113,7 +113,7 @@ parser.add_argument("--temperature_hot", type=float, default=1.5)
 
 parser.add_argument("--temperature_cold", type=float, default=1)
 
-parser.add_argument("--prompt_noise", type=float, default=0.05)
+parser.add_argument("--prompt_noise", type=float, default=0.0)
 
 parser.add_argument("--dirty_debug", action="store_true", default=False)
 
@@ -298,7 +298,6 @@ quiz_machine = quiz_machine.QuizMachine(
     problem=problem,
     batch_size=args.inference_batch_size,
     result_dir=args.result_dir,
-    prompt_noise=args.prompt_noise,
     logger=log_string,
     device=main_device,
 )
@@ -1098,7 +1097,7 @@ def model_ae_proba_solutions(model, input, log_proba=False):
 nb_diffusion_iterations = 25
 
 
-def degrade_input(input, mask_generate, nb_iterations):
+def degrade_input_to_generate(input, mask_generate, nb_iterations):
     noise = torch.randint(
         quiz_machine.problem.nb_colors, input.size(), device=input.device
     )
@@ -1116,20 +1115,36 @@ def degrade_input(input, mask_generate, nb_iterations):
     return result
 
 
-def targets_and_prediction(model, input, mask_generate):
+def targets_and_prediction(model, input, mask_generate, prompt_noise=0.0):
     d = deterministic(mask_generate)
+
     probs_iterations = 0.1 ** torch.linspace(
         0, 1, args.nb_diffusion_iterations, device=input.device
     )
+
     probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
     probs_iterations = probs_iterations.expand(input.size(0), -1)
     dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
-    N0 = dist.sample()
-    N1 = N0 + 1
-    N0 = (1 - d) * N0
-    N1 = (1 - d) * N1 + d * args.nb_diffusion_iterations
 
-    targets, input = degrade_input(input, mask_generate, (0 * N1, N1))
+    # N0 = dist.sample()
+    # N1 = N0 + 1
+    # N0 = (1 - d) * N0
+    # N1 = (1 - d) * N1 + d * args.nb_diffusion_iterations
+
+    N0 = input.new_zeros(input.size(0))
+    N1 = dist.sample() + 1
+
+    targets, input = degrade_input_to_generate(input, mask_generate, (N0, N1))
+
+    if prompt_noise > 0:
+        mask_prompt_noise = (
+            torch.rand(input.size(), device=input.device()) <= prompt_noise
+        ).long()
+        noise = torch.randint(
+            quiz_machine.problem.nb_colors, input.size(), device=input.device
+        )
+        noisy_input = (1 - mask_prompt_noise) * input + mask_prompt_noise * noise
+        input = (1 - mask_generate) * noisy_input + mask_generate * input
 
     input_with_mask = NTC_channel_cat(input, mask_generate)
     logits = model(input_with_mask)
@@ -1250,9 +1265,7 @@ def run_ae_test(model, quiz_machine, n_epoch, c_quizzes=None, local_device=main_
 ######################################################################
 
 
-def one_ae_epoch(
-    model, other_models, quiz_machine, n_epoch, c_quizzes, local_device=main_device
-):
+def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_device):
     model.train().to(local_device)
     optimizer_to(model.optimizer, local_device)
 
@@ -1273,7 +1286,9 @@ def one_ae_epoch(
         if nb_train_samples % args.batch_size == 0:
             model.optimizer.zero_grad()
 
-        targets, logits = targets_and_prediction(model, input, mask_generate)
+        targets, logits = targets_and_prediction(
+            model, input, mask_generate, prompt_noise=args.prompt_noise
+        )
 
         loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
         acc_train_loss += loss.item() * input.size(0)
@@ -1572,7 +1587,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
     # --------------------------------------------------------------------
 
-    # one_ae_epoch(models[0], models, quiz_machine, n_epoch, main_device)
+    # one_ae_epoch(models[0], quiz_machine, n_epoch, main_device)
     # exit(0)
 
     log_string(f"{time_train=} {time_c_quizzes=}")
@@ -1612,7 +1627,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         t = threading.Thread(
             target=one_ae_epoch,
             daemon=True,
-            args=(model, models, quiz_machine, n_epoch, c_quizzes, gpu),
+            args=(model, quiz_machine, n_epoch, c_quizzes, gpu),
         )
 
         threads.append(t)
index ce4d4f5..f1eb9db 100755 (executable)
@@ -67,7 +67,6 @@ class QuizMachine:
         problem,
         batch_size,
         result_dir,
-        prompt_noise,
         logger,
         device=torch.device("cpu"),
     ):
@@ -79,7 +78,6 @@ class QuizMachine:
         self.logger = logger
         self.prompt_len = None
         self.answer_len = None
-        self.prompt_noise = prompt_noise
 
         # quad_order, quad_generate, quad_noise, quad_loss
         self.train_structures = [
@@ -186,13 +184,6 @@ class QuizMachine:
             quad_order, quad_generate, quad_noise, quad_loss = s
             i = order_ids == j
             quizzes[i] = self.problem.reconfigure(quizzes[i], quad_order=quad_order)
-            if self.prompt_noise > 0.0:
-                quizzes[i] = self.problem.inject_noise(
-                    quizzes[i],
-                    self.prompt_noise,
-                    quad_order=quad_order,
-                    quad_noise=quad_noise,
-                )
             quiz_mask_generate[i] = self.make_quiz_mask(
                 quizzes=quizzes[i], quad_order=quad_order, quad_mask=quad_generate
             )
@@ -335,11 +326,6 @@ class QuizMachine:
             device=device,
         )
 
-        # if self.prompt_noise > 0.0 and quad_noise is not None:
-        # c_quizzes = self.problem.inject_noise(
-        # c_quizzes, self.prompt_noise, quad_order=quad_order, quad_noise=quad_noise
-        # )
-
         with torch.autograd.no_grad():
             t = model.training
             model.eval()