From: François Fleuret
Date: Mon, 19 Aug 2024 20:19:46 +0000 (+0200)
Subject: Update.
X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=734fd15f40e45f4e560ee64a777a2759213e99e0;p=culture.git

Update.
---

diff --git a/main.py b/main.py
index cd1e10f..901e91c 100755
--- a/main.py
+++ b/main.py
@@ -456,9 +456,9 @@ def model_modifier_cold(model):
 
 
 c_quizzes_procedure = [
-    (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_modifier_hot),
-    (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_modifier_cold),
-    (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_modifier_cold),
+    # (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_modifier_hot),
+    # (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_modifier_cold),
+    (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), model_modifier_hot),
     # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_modifier_cold),
 ]
 
@@ -562,8 +562,9 @@ def create_c_quizzes(
     train_c_quiz_bags,
     nb_for_test,
     test_c_quiz_bags,
+    local_device=main_device,
 ):
-    nb_validated, nb_to_validate = 0, (nb_for_train + nb_for_test) * len(models)
+    nb_validated, nb_to_validate = 0, nb_for_train + nb_for_test
     nb_generated, nb_to_generate_per_iteration = 0, nb_to_validate
 
     start_time = time.perf_counter()
@@ -600,10 +601,9 @@ def create_c_quizzes(
             mask=(0, 0, 0, 1),
         )
 
-        keep = (
-            model_proba_solutions(main_model, main_solution)
-            < args.proba_not_understands
-        )
+        main_probas = model_proba_solutions(main_model, main_solution)
+        log_string(f"main_probas {main_probas}")
+        keep = main_probas < args.proba_not_understands
         c_quizzes = c_quizzes[keep]
 
         # If there are some quizzes that the main model cannot solve,
@@ -624,6 +624,7 @@ def create_c_quizzes(
            )
 
            probas = model_proba_solutions(model, solution)
+           log_string(f"probas {probas}")
           keep = probas >= c_quizzes_proba
           c_quizzes = solution[keep]
           c_quizzes_proba[keep] = probas[keep]
@@ -652,19 +653,9 @@ def create_c_quizzes(
     # Save some images
 
     c_quizzes = torch.cat(recorded, dim=0)
-
-    l = [
-        model_proba_solutions(model, c_quizzes) for model in [main_model] + other_models
-    ]
-    probas = torch.cat([x[:, None] for x in l], dim=1)
-    comments = []
-    for l in probas:
-        comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
-
-    filename = f"culture_c_quiz_{n_epoch:04d}_{model.id:02d}.png"
-    quiz_machine.problem.save_quizzes_as_image(
-        args.result_dir, filename, c_quizzes[:128], comments=comments
-    )
+    n = (c_quizzes.size(0) * nb_for_train) // (nb_for_train + nb_for_test)
+    train_c_quiz_bags.append(c_quizzes[:n])
+    test_c_quiz_bags.append(c_quizzes[n:])
 
     log_string(
         f"nb_c_quizzes model {model.id} train {sum([q.size(0) for q in train_c_quiz_bags ])} test {sum([q.size(0) for q in test_c_quiz_bags ])}"
@@ -750,8 +741,8 @@ if args.resume:
         state = torch.load(os.path.join(args.result_dir, filename))
         log_string(f"successfully loaded {filename}")
         current_epoch = state["current_epoch"]
-        train_c_quiz_bags = d["train_c_quiz_bags"]
-        test_c_quiz_bags = d["test_c_quiz_bags"]
+        train_c_quiz_bags = state["train_c_quiz_bags"]
+        test_c_quiz_bags = state["test_c_quiz_bags"]
     except FileNotFoundError:
         log_string(f"cannot find {filename}")
         pass
@@ -865,7 +856,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     log_string(f"--- epoch {n_epoch} ----------------------------------------")
 
     cta = " ".join([f"{float(m.test_accuracy):.04f}" for m in models])
-    log_string(f"current_test_accuracies {cta}")
+    log_string(f"test_accuracies {cta}")
 
     ##################################################
 
@@ -880,6 +871,17 @@ for n_epoch in range(current_epoch, args.nb_epochs):
            test_c_quiz_bags=test_c_quiz_bags,
        )
 
+       c_quizzes = train_c_quiz_bags[-128:]
+       l = [model_proba_solutions(model, c_quizzes) for model in models]
+       probas = torch.cat([x[:, None] for x in l], dim=1)
+       comments = []
+       for l in probas:
+           comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
+       filename = f"culture_c_quiz_{n_epoch:04d}.png"
+       quiz_machine.problem.save_quizzes_as_image(
+           args.result_dir, filename, c_quizzes, comments=comments
+       )
+
    for model in models:
        new_model = mygpt.MyGPT(
            vocabulary_size=vocabulary_size,
@@ -897,13 +899,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     ##################################################
     # Select, improve, and eval the worst model(s)
 
-    ranked_models = sorted(
-        models,
-        # This ugly recipe will pick the worst if there some below
-        # args.accuracy_to_make_c_quizzes or one at random if they
-        # are all above
-        key=lambda m: float(m.test_accuracy),
-    )
+    ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
 
     weakest_models = ranked_models[: len(gpus)]
 
@@ -925,8 +921,8 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     for t in threads:
         t.join()
 
-    for model in weakest_models:
-        save_additional_results(n_epoch, model, models, c_quizzes_procedure)
+    # for model in weakest_models:
+    #     save_additional_results(n_epoch, model, models, c_quizzes_procedure)
 
     # Save the models to disk
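
The create_c_quizzes() changes above keep a candidate quiz only when the main model assigns its own solution a probability below args.proba_not_understands, track for each surviving quiz the best solution probability reached by the other models ("keep = probas >= c_quizzes_proba"), and split the recorded quizzes between the train and test bags in the nb_for_train:nb_for_test ratio instead of saving images per model. The following is a minimal, self-contained sketch of that selection-and-split logic, not code from the commit: random tensors stand in for model_proba_solutions(), the thresholds and sizes are made-up values, and the zero initialization of c_quizzes_proba is an assumption.

import torch

# Hypothetical stand-ins for args.* values and the real quiz tensors.
proba_not_understands = 0.5
nb_for_train, nb_for_test = 100, 25
nb_candidates, nb_other_models = 256, 3

main_probas = torch.rand(nb_candidates)                    # main model's confidence in its own solutions
other_probas = torch.rand(nb_other_models, nb_candidates)  # other models' confidence

# Keep only the quizzes the main model cannot solve confidently.
keep = main_probas < proba_not_understands
candidates = torch.arange(nb_candidates)[keep]
probas = other_probas[:, keep]

# For each surviving quiz, track the best solution probability reached by the
# other models, mirroring the running "keep = probas >= c_quizzes_proba" update.
c_quizzes_proba = torch.zeros(candidates.size(0))
for m in range(nb_other_models):
    better = probas[m] >= c_quizzes_proba
    c_quizzes_proba[better] = probas[m][better]

# Split the recorded quizzes into train/test bags with the nb_for_train:nb_for_test ratio.
recorded = candidates
n = (recorded.size(0) * nb_for_train) // (nb_for_train + nb_for_test)
train_bag, test_bag = recorded[:n], recorded[n:]
print(f"kept {recorded.size(0)} candidates -> train {train_bag.size(0)} / test {test_bag.size(0)}")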