From: François Fleuret Date: Tue, 16 Jul 2024 18:47:42 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=51f2897a7eb14ae72ee7eee788d876915ead4370;p=culture.git Update. --- diff --git a/main.py b/main.py index 2b71950..41efc86 100755 --- a/main.py +++ b/main.py @@ -410,12 +410,13 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 start_time = time.perf_counter() - nb_validated = torch.zeros(len(models)) + nb_validated = torch.zeros(len(models), dtype=torch.int64) while nb_validated.sum() < nb_to_create: # We balance the number of quizzes per model - model_for_generation = models[nb_validated.argmin()] + model_for_generation = sorted(models, key=lambda m: nb_validated[m.id])[0] + print(nb_validated, "using", model_for_generation.id) c_quizzes = quiz_machine.generate_c_quizzes( nb_to_generate_per_iteration, @@ -426,7 +427,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 c_quizzes = keep_good_quizzes(models, c_quizzes) - nb_validated[model.id] += c_quizzes.size(0) + nb_validated[model_for_generation.id] += c_quizzes.size(0) total_nb_validated = nb_validated.sum().item() recorded.append(c_quizzes) @@ -442,7 +443,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 e = "???" log_string( - f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_create} (finishes {e} -- {(total_nb_validated * 3600)/duration}/h)" + f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_create} (finishes {e} -- {(total_nb_validated * 3600)/duration:0.1f}/h)" ) validated_quizzes = torch.cat(recorded, dim=0)