From eecddf4b449e4f06d2aabcc84b789d5b75730810 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 18 Aug 2024 17:39:43 +0200 Subject: [PATCH] Update. --- main.py | 23 ++++++++++++++++------- quiz_machine.py | 15 ++++++++++++++- 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/main.py b/main.py index db16214..046514d 100755 --- a/main.py +++ b/main.py @@ -61,6 +61,8 @@ parser.add_argument("--nb_new_c_quizzes_for_train", type=int, default=None) parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None) +parser.add_argument("--c_quiz_multiplier", type=int, default=1) + parser.add_argument("--learning_rate", type=float, default=5e-4) parser.add_argument("--lambda_H", type=float, default=0.0) @@ -389,7 +391,7 @@ def one_epoch(model, quiz_machine, local_device=main_device): nb_train_samples, acc_train_loss = 0, 0.0 full_input, full_mask_loss = quiz_machine.data_input( - args.nb_train_samples, model.train_c_quiz_bags + args.nb_train_samples, model.train_c_quiz_bags, args.c_quiz_multiplier ) src = zip(full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)) @@ -900,10 +902,10 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)") ###################################################################### if args.nb_new_c_quizzes_for_train is None: - args.nb_new_c_quizzes_for_train = args.nb_train_samples // 1000 + args.nb_new_c_quizzes_for_train = args.nb_train_samples // 250 if args.nb_new_c_quizzes_for_test is None: - args.nb_new_c_quizzes_for_test = args.nb_test_samples // 1000 + args.nb_new_c_quizzes_for_test = args.nb_test_samples // 250 log_string( f"nb_new_c_quizzes_for_train {args.nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {args.nb_new_c_quizzes_for_test}" @@ -1126,8 +1128,6 @@ for n_epoch in range(current_epoch, args.nb_epochs): log_string(f"current_best_test_accuracies {cta}") ################################################## - # If all the models are good enough, generate new quizzes and - # re-compute the test errors for model in models: if model.test_accuracy >= args.accuracy_to_make_c_quizzes: @@ -1136,7 +1136,6 @@ for n_epoch in range(current_epoch, args.nb_epochs): ) model.best_dict = copy.deepcopy(model.state_dict()) model.best_test_accuracy = model.test_accuracy - model.test_accuracy = 0.0 # we restart if total_time_generating_c_quizzes == 0: @@ -1167,7 +1166,17 @@ for n_epoch in range(current_epoch, args.nb_epochs): # Select, improve, and eval the worst model(s) if total_time_training_models <= total_time_generating_c_quizzes: - ranked_models = sorted(models, key=lambda m: float(m.test_accuracy)) + ranked_models = sorted( + models, + # This ugly recipe will pick the worst if there some below + # args.accuracy_to_make_c_quizzes or one at random if they + # are all above + key=lambda m: float( + m.test_accuracy + if m.test_accuracy < args.accuracy_to_make_c_quizzes + else args.accuracy_to_make_c_quizzes + torch.rand(1).item() + ), + ) weakest_models = ranked_models[: len(gpus)] diff --git a/quiz_machine.py b/quiz_machine.py index 18136e8..a0b007a 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -140,10 +140,23 @@ class QuizMachine: ###################################################################### - def data_input(self, nb_samples, c_quiz_bags): + def data_input(self, nb_samples, c_quiz_bags, c_quiz_multiplier=1): if len(c_quiz_bags) > 0: c_quizzes = torch.cat(c_quiz_bags, dim=0) + if c_quiz_multiplier > 1: + n = min(c_quiz_multiplier, (nb_samples // 2) // c_quizzes.size(0)) + body = c_quizzes.repeat(n, 1) + if n < c_quiz_multiplier: + tail = c_quizzes[ + torch.randperm(c_quizzes.size(0))[ + : nb_samples // 2 - body.size(0) + ] + ] + c_quizzes = torch.cat([body, tail], dim=0) + else: + c_quizzes = body + if c_quizzes.size(0) > nb_samples // 2: i = torch.randperm(c_quizzes.size(0))[: nb_samples // 2] c_quizzes = c_quizzes[i] -- 2.39.5