Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 15 Aug 2024 10:59:47 +0000 (12:59 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 15 Aug 2024 10:59:47 +0000 (12:59 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index 4326491..4375985 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -452,7 +452,7 @@ c_quizzes_procedure = [
 ######################################################################
 
 
-def save_additional_results(model, models, c_quizzes_procedure):
+def save_additional_results(n_epoch, model, models, c_quizzes_procedure):
     # Save generated quizzes with the successive generation steps
 
     recorder = []
@@ -592,6 +592,9 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
                         proba_other_solutions = model_proba_solutions(
                             model, solved_c_quizzes[s]
                         )
+
+                        # proba_other_solutions += torch.rand(proba_other_solutions.size()) * 1e-6
+
                         proba_other_solutions[dont_get_this_quiz] = -1
                         # print(
                         # f"\nDEBUG {proba_own_solution[s,model.id]=} {proba_other_solutions=}\n"
@@ -945,8 +948,101 @@ if args.dirty_debug:
     args.nb_new_c_quizzes_for_train = 100
     args.nb_new_c_quizzes_for_test = 10
 
-if args.test == "gen":
-    save_additional_results(model, models, c_quizzes_procedure)
+######################################################################
+######################################################################
+
+
+class Recorder(nn.Module):
+    def __init__(self, tape):
+        super().__init__()
+        self.tape = tape
+
+    def forward(self, input):
+        self.tape.append(input)
+        return input
+
+
+if args.test == "mlp":
+    model = models[0]
+    tape_input, tape_output = [], []
+    L = len(model.trunk)
+    model.trunk.insert(L // 2 + 1, Recorder(tape_output))
+    model.trunk.insert(L // 2, Recorder(tape_input))
+
+    print(model.trunk)
+    train_input = quiz_machine.problem.generate_w_quizzes(args.nb_train_samples)
+
+    with torch.autograd.no_grad():
+        model.to(main_device).eval()
+        for input in train_input.split(args.batch_size):
+            input = input.to(main_device)
+            output = model(mygpt.BracketedSequence(input)).x
+
+    train_input = torch.cat([bs.x for bs in tape_input], dim=0)
+    train_targets = torch.cat([bs.x for bs in tape_output], dim=0)
+
+    print(f"{train_input.size()=} {train_targets.size()=}")
+
+    exit(0)
+
+######################################################################
+######################################################################
+
+if args.test == "reject":
+    record = []
+
+    c_quizzes_procedure = [
+        (("f_B", "f_A", "A", "B"), (1, 1, 1, 1), model_modifier_hot),
+        (("f_B", "B", "f_A", "A"), (0, 0, 1, 1), model_modifier_cold),
+        (("f_B", "f_A", "A", "B"), (0, 0, 0, 1), model_modifier_cold),
+        (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_modifier_cold),
+        (("f_B", "B", "f_A", "A"), (0, 0, 1, 1), model_modifier_cold),
+        (("f_B", "f_A", "A", "B"), (0, 0, 0, 1), model_modifier_cold),
+        (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_modifier_cold),
+        (("f_B", "B", "f_A", "A"), (0, 0, 1, 1), model_modifier_cold),
+        (("f_B", "f_A", "A", "B"), (0, 0, 0, 1), model_modifier_cold),
+        (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_modifier_cold),
+    ]
+
+    while sum([x.size(0) for x in record]) < 64:
+        model = models[torch.randint(len(models), (1,)).item()]
+        c_quizzes = quiz_machine.generate_c_quizzes(
+            64,
+            model_for_generation=model,
+            procedure=c_quizzes_procedure,
+        )
+
+        p = quiz_machine.models_logprobas(
+            model,
+            c_quizzes,
+            ("A", "f_A", "B", "f_B"),
+            (1, 1, 1, 1),
+            temperature=1,
+        ).exp()
+
+        p_hot = quiz_machine.models_logprobas(
+            model,
+            c_quizzes,
+            ("A", "f_A", "B", "f_B"),
+            (1, 1, 1, 1),
+            temperature=args.temperature_hot,
+        ).exp()
+
+        to_keep = p_hot * torch.rand(p_hot.size(), device=p_hot.device) >= p
+        record.append(c_quizzes[to_keep])
+
+        print("NB_KEPT", sum([x.size(0) for x in record]))
+
+    filename = f"sampling_with_rejection.png"
+
+    quiz_machine.problem.save_quizzes_as_image(
+        args.result_dir,
+        filename,
+        quizzes=c_quizzes,
+    )
+
+    log_string(f"wrote {filename}")
+
     exit(0)
 
 ######################################################################
@@ -1018,7 +1114,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         log_string(f"wrote {filename}")
 
     for model in weakest_models:
-        save_additional_results(model, models, c_quizzes_procedure)
+        save_additional_results(n_epoch, model, models, c_quizzes_procedure)
 
     ######################################################################
 
index 1fe2e94..0bdaaec 100755 (executable)
@@ -294,6 +294,7 @@ class QuizMachine:
         struct,
         mask_loss,
         mask_noise=None,
+        temperature=1.0,
         device=None,
     ):
         if device is None:
@@ -323,7 +324,7 @@ class QuizMachine:
                 quiz_mask_loss = self.make_quiz_mask(
                     input, struct=struct, mask=mask_loss
                 )
-                output = model(mygpt.BracketedSequence(input)).x
+                output = model(mygpt.BracketedSequence(input)).x / temperature
                 l[...] = (
                     -F.cross_entropy(output.transpose(1, 2), input, reduction="none")
                     * quiz_mask_loss