From: François Fleuret Date: Tue, 13 Aug 2024 21:06:40 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=ce4bf29737543de0166b6bcc0e5f7406d986cec8;p=culture.git Update. --- diff --git a/main.py b/main.py index 8e06bb2..0b9a86e 100755 --- a/main.py +++ b/main.py @@ -446,7 +446,7 @@ c_quizzes_procedure = [ (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_modifier_hot), (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_modifier_cold), (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_modifier_cold), - (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_modifier_cold), + # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_modifier_cold), ] ###################################################################### @@ -838,7 +838,15 @@ if args.test == "func": output = model(mygpt.BracketedSequence(input[:, : 3 * L])).x dist = torch.distributions.categorical.Categorical(logits=output) - input[:, 3 * L :] = dist.sample() + input[:, 3 * L + 1 :] = dist.sample()[:, 1:] + + problem.save_quizzes_as_image( + args.result_dir, + f"thinker_prediction_{n_epoch:04d}.png", + quizzes=input, + # predicted_parts=predicted_parts, + # correct_parts=correct_parts, + ) ######################################################################