From: François Fleuret Date: Sat, 20 Jul 2024 22:45:18 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=ece66035329cf86a322560779672d84652dd2a12;p=culture.git Update. --- diff --git a/main.py b/main.py index 7588a50..653f5f5 100755 --- a/main.py +++ b/main.py @@ -472,11 +472,19 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 temperature_cold=args.temperature_cold, ) - recorded_too_simple.append( - keep_good_quizzes(models, c_quizzes, required_nb_failures=0) + c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)] + + nc = quiz_machine.solution_nb_correct(models, c_quizzes) + + count_nc = tuple( + n.item() for n in F.one_hot(nc, num_classes=len(models) + 1).sum(dim=0) ) - c_quizzes = keep_good_quizzes(models, c_quizzes) + log_string(f"nb_correct {count_nc}") + + recorded_too_simple.append(c_quizzes[nc == len(models)]) + + c_quizzes = c_quizzes[nc == len(models) - 1] nb_validated[model_for_generation.id] += c_quizzes.size(0) total_nb_validated = nb_validated.sum().item() @@ -517,7 +525,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 ###################################################################### # save images - vq = validated_quizzes[:128] + vq = validated_quizzes[torch.randperm(validated_quizzes.size(0))[:128]] if vq.size(0) > 0: prefix = f"culture_c_quiz_{n_epoch:04d}" @@ -525,7 +533,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 args.result_dir, prefix, vq, show_part_to_predict=False ) - vq = too_simple_quizzes[:128] + vq = too_simple_quizzes if vq.size(0) > 0: prefix = f"culture_c_quiz_{n_epoch:04d}_too_simple" @@ -642,6 +650,11 @@ if args.dirty_debug: ###################################################################### for n_epoch in range(current_epoch, args.nb_epochs): + state = {"current_epoch": n_epoch} + filename = "state.pth" + torch.save(state, os.path.join(args.result_dir, filename)) + log_string(f"wrote {filename}") + log_string(f"--- epoch {n_epoch} ----------------------------------------") cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models]) @@ -700,11 +713,6 @@ for n_epoch in range(current_epoch, args.nb_epochs): ) log_string(f"wrote {filename}") - state = {"current_epoch": n_epoch} - filename = "state.pth" - torch.save(state, os.path.join(args.result_dir, filename)) - log_string(f"wrote {filename}") - # Renew the training samples for model in weakest_models: