From: François Fleuret Date: Wed, 31 Jul 2024 19:02:38 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=1fcb346953e6a7999342339afe0c00ffa92834af;p=culture.git Update. --- diff --git a/main.py b/main.py index 4903585..cce747a 100755 --- a/main.py +++ b/main.py @@ -682,7 +682,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 # 2->proba>=proba_understands and 1 otherwise. -def generate_c_quizz_with_generator(generator, quiz_machine, nb): +def generate_c_quizzes_with_generator(generator, quiz_machine, nb): generator.to(main_device) struct = ("A", "f_A", "B", "f_B") @@ -695,13 +695,13 @@ def generate_c_quizz_with_generator(generator, quiz_machine, nb): num_classes=args.nb_gpts, ) - prolog_c_quizzes = token_prolog_0 * i + token_prolog_2 * (1 - i) - prolog_ar_mask = ar_mask.new_zeros(ar_mask.size(0), prolog_c_quizzes.size(1)) + prologs_c_quizzes = token_prolog_0 * i + token_prolog_2 * (1 - i) + prologs_ar_mask = ar_mask.new_zeros(ar_mask.size(0), prologs_c_quizzes.size(1)) - prologued_c_quizzes = torch.cat([prolog_c_quizzes, c_quizzes], dim=1).to( + prologued_c_quizzes = torch.cat([prologs_c_quizzes, c_quizzes], dim=1).to( main_device ) - prologued_ar_mask = torch.cat([prolog_ar_mask, ar_mask], dim=1).to(main_device) + prologued_ar_mask = torch.cat([prologs_ar_mask, ar_mask], dim=1).to(main_device) seq_logproba = torch.zeros( prologued_c_quizzes.size(0), device=prologued_c_quizzes.device @@ -729,7 +729,9 @@ def generate_c_quizz_with_generator(generator, quiz_machine, nb): prologued_c_quizzes * (prologued_c_quizzes < vocabulary_size).long() ) - return prologued_c_quizzes[:, prolog_c_quizzes.size(1) :].to("cpu") + c_quizzes = prologued_c_quizzes[:, prologs_c_quizzes.size(1) :] + + return c_quizzes.to("cpu"), prologs_c_quizzes.to("cpu") def batches_for_generator(generator, quiz_machine, models, fraction_w_quizzes=1.0): @@ -746,7 +748,7 @@ def batches_for_generator(generator, quiz_machine, models, fraction_w_quizzes=1. ) else: # Or we use the generator itself to generate them - c_quizzes = generate_c_quizz_with_generator( + c_quizzes, _ = generate_c_quizzes_with_generator( generator, quiz_machine, args.batch_size ) @@ -770,13 +772,13 @@ def batches_for_generator(generator, quiz_machine, models, fraction_w_quizzes=1. u2 = probas >= args.proba_understands u1 = (u0 | u2) == False - prolog = ( + prologs = ( (u0.long() * token_prolog_0) + (u1.long() * token_prolog_1) + (u2.long() * token_prolog_2) ) - prologued_c_quizzes = torch.cat([prolog, c_quizzes], dim=1) + prologued_c_quizzes = torch.cat([prologs, c_quizzes], dim=1) # nb_u2 = u2.long().sum(dim=1) # nb_u0 = u0.long().sum(dim=1) @@ -1031,17 +1033,6 @@ if args.dirty_debug: ###################################################################### if args.test_generator: - filename = f"generator.pth" - - try: - d = torch.load(os.path.join(args.result_dir, filename)) - generator.load_state_dict(d[0]) - generator.main_test_accuracy = d[1] - log_string(f"successfully loaded {filename}") - except FileNotFoundError: - log_string(f"cannot find {filename}") - pass - token_prolog_0 = vocabulary_size + 0 token_prolog_1 = vocabulary_size + 1 token_prolog_2 = vocabulary_size + 2 @@ -1060,6 +1051,17 @@ if args.test_generator: generator.main_test_accuracy = 0.0 + filename = f"generator.pth" + + try: + d = torch.load(os.path.join(args.result_dir, filename)) + generator.load_state_dict(d[0]) + generator.main_test_accuracy = d[1] + log_string(f"successfully loaded {filename}") + except FileNotFoundError: + log_string(f"cannot find {filename}") + pass + for n_epoch in range(args.nb_epochs): one_generator_epoch( generator, @@ -1076,7 +1078,7 @@ if args.test_generator: ) log_string(f"wrote {filename}") - c_quizzes = generate_c_quizz_with_generator( + c_quizzes, prologs = generate_c_quizzes_with_generator( generator, quiz_machine, args.batch_size ) @@ -1086,12 +1088,34 @@ if args.test_generator: models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1) ) - print(seq_logproba.exp()) + probas = seq_logproba.exp() + + u0 = probas <= args.proba_not_understands + u2 = probas >= args.proba_understands + u1 = (u0 | u2) == False + + predicted_prologs = ( + (u0.long() * token_prolog_0) + + (u1.long() * token_prolog_1) + + (u2.long() * token_prolog_2) + ) comments = [] - for l in seq_logproba: - comments.append("proba " + " ".join([f"{x.exp().item():.02f}" for x in l])) + nb_errors = (predicted_prologs != prologs).long().sum() + nb_total = prologs.numel() + + log_string(f"generator_error {nb_errors} / {nb_total}") + + def readable(prologs): + return (prologs == token_prolog_1) + 2 * (prologs == token_prolog_2) + + for aa, ee, ff in zip(probas, readable(predicted_prologs), readable(prologs)): + sa = "prolog " + " ".join( + [f"{e.item()}/{f.item()}" for e, f in zip(ee, ff)] + ) + sp = "proba " + " ".join([f"{p.item():.02f}" for p in aa]) + comments.append(sa + "\n" + sp) filename = f"generator_batch_{n_epoch:04d}.png" quiz_machine.problem.save_quizzes_as_image(