From: François Fleuret Date: Wed, 21 Aug 2024 20:29:53 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=a435bbc5801f392b60ad4508be9a8265f1e525bb;p=culture.git Update. --- diff --git a/main.py b/main.py index 410ce7d..66b7ffd 100755 --- a/main.py +++ b/main.py @@ -1089,55 +1089,66 @@ for n_epoch in range(current_epoch, args.nb_epochs): if total_time_generating_c_quizzes == 0: total_time_training_models = 0 - if ( - min([m.gen_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes - and total_time_training_models >= total_time_generating_c_quizzes - ): - ###################################################################### - # Re-initalize if there are enough culture quizzes - + if min([m.gen_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes: if args.reboot: - nb_c_quizzes_per_model = [ - sum([x.size(0) for x in model.train_c_quiz_bags]) for model in models - ] + for model in models: + model.current_dict = copy.deepcopy(model.state_dict()) + model.load_state_dict(model.gen_state_dict) + + while True: + record_new_c_quizzes( + models, + quiz_machine, + args.nb_new_c_quizzes_for_train, + args.nb_new_c_quizzes_for_test, + ) - p = tuple( - f"{(x*100)/args.nb_train_samples:.02f}%" for x in nb_c_quizzes_per_model - ) + nb_c_quizzes_per_model = [ + sum([x.size(0) for x in model.train_c_quiz_bags]) + for model in models + ] - log_string(f"nb_c_quizzes_per_model {p}") + p = tuple( + f"{(x*100)/args.nb_train_samples:.02f}%" + for x in nb_c_quizzes_per_model + ) - m = max(nb_c_quizzes_per_model) + log_string(f"nb_c_quizzes_per_model {p}") - if m >= args.nb_train_samples: - model = models[nb_c_quizzes_per_model.index(m)] - common_c_quiz_bags.append(torch.cat(model.train_c_quiz_bags, dim=0)) - nb_common_c_quizzes = sum([x.size(0) for x in common_c_quiz_bags]) - log_string( - f"rebooting the models with {nb_common_c_quizzes} culture quizzes" - ) + m = max(nb_c_quizzes_per_model) - models = create_models() - total_time_generating_c_quizzes = 0 - total_time_training_models = 0 + if m >= args.nb_train_samples: + break - for model in models: - model.current_dict = copy.deepcopy(model.state_dict()) - model.load_state_dict(model.gen_state_dict) + model = models[nb_c_quizzes_per_model.index(m)] + common_c_quiz_bags.append(torch.cat(model.train_c_quiz_bags, dim=0)) + nb_common_c_quizzes = sum([x.size(0) for x in common_c_quiz_bags]) + log_string( + f"rebooting the models with {nb_common_c_quizzes} culture quizzes" + ) - start_time = time.perf_counter() + models = create_models() + total_time_generating_c_quizzes = 0 + total_time_training_models = 0 - record_new_c_quizzes( - models, - quiz_machine, - args.nb_new_c_quizzes_for_train, - args.nb_new_c_quizzes_for_test, - ) + elif total_time_training_models >= total_time_generating_c_quizzes: + for model in models: + model.current_dict = copy.deepcopy(model.state_dict()) + model.load_state_dict(model.gen_state_dict) - total_time_generating_c_quizzes += time.perf_counter() - start_time + start_time = time.perf_counter() - for model in models: - model.load_state_dict(model.current_dict) + record_new_c_quizzes( + models, + quiz_machine, + args.nb_new_c_quizzes_for_train, + args.nb_new_c_quizzes_for_test, + ) + + total_time_generating_c_quizzes += time.perf_counter() - start_time + + for model in models: + model.load_state_dict(model.current_dict) ################################################## # Select, improve, and eval the worst model(s)