From e8994543452e3cb885515cd5bbf97e67b854ae4b Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 18 Sep 2024 10:21:02 +0200 Subject: [PATCH] Update. --- main.py | 78 +++++++++++++++++++++++++++++---------------------------- 1 file changed, 40 insertions(+), 38 deletions(-) diff --git a/main.py b/main.py index 84224e9..86a3ae9 100755 --- a/main.py +++ b/main.py @@ -670,6 +670,26 @@ for i in range(args.nb_models): ###################################################################### +def evaluate_quizzes(c_quizzes, models, local_device): + nb_correct, nb_wrong = 0, 0 + + for model in models: + model = copy.deepcopy(model).to(local_device).eval() + result = predict_full(model, c_quizzes, local_device=local_device) + nb_mistakes = (result != c_quizzes).long().sum(dim=1) + nb_correct += (nb_mistakes == 0).long() + nb_wrong += nb_mistakes >= args.nb_mistakes_to_be_wrong + + to_keep = (nb_correct >= args.nb_have_to_be_correct) & ( + nb_wrong >= args.nb_have_to_be_wrong + ) + + return to_keep, nb_correct, nb_wrong + + +###################################################################### + + def generate_c_quizzes(models, nb, local_device=main_device): record = [] nb_validated = 0 @@ -694,17 +714,8 @@ def generate_c_quizzes(models, nb, local_device=main_device): # Select the ones that are solved properly by some models and # not understood by others - nb_correct, nb_wrong = 0, 0 - - for i, model in enumerate(models): - model = copy.deepcopy(model).to(local_device).eval() - result = predict_full(model, c_quizzes, local_device=local_device) - nb_mistakes = (result != c_quizzes).long().sum(dim=1) - nb_correct += (nb_mistakes == 0).long() - nb_wrong += nb_mistakes >= args.nb_mistakes_to_be_wrong - - to_keep = (nb_correct >= args.nb_have_to_be_correct) & ( - nb_wrong >= args.nb_have_to_be_wrong + to_keep, nb_correct, nb_wrong = evaluate_quizzes( + c_quizzes, models, local_device ) nb_validated += to_keep.long().sum().item() @@ -743,31 +754,19 @@ def generate_c_quizzes(models, nb, local_device=main_device): ###################################################################### -def save_c_quizzes_with_scores(models, c_quizzes, filename, solvable_only=False): - l = [] - - c_quizzes = c_quizzes.to(main_device) - - with torch.autograd.no_grad(): - to_keep, nb_correct, nb_wrong, record_wrong = quiz_validation( - models, - c_quizzes, - main_device, - nb_have_to_be_correct=args.nb_have_to_be_correct, - nb_have_to_be_wrong=0, - nb_mistakes_to_be_wrong=args.nb_mistakes_to_be_wrong, - nb_hints=None, - ) +def save_quiz_image( + models, c_quizzes, filename, solvable_only=False, local_device=main_device +): + c_quizzes = c_quizzes.to(local_device) - if solvable_only: - c_quizzes = c_quizzes[to_keep] - nb_correct = nb_correct[to_keep] - nb_wrong = nb_wrong[to_keep] + to_keep, nb_correct, nb_wrong = evaluate_quizzes(c_quizzes, models, local_device) - comments = [] + if solvable_only: + c_quizzes = c_quizzes[to_keep] + nb_correct = nb_correct[to_keep] + nb_wrong = nb_wrong[to_keep] - for c, w in zip(nb_correct, nb_wrong): - comments.append(f"nb_correct {c} nb_wrong {w}") + comments = [f"nb_correct {c} nb_wrong {w}" for c, w in zip(nb_correct, nb_wrong)] quiz_machine.problem.save_quizzes_as_image( args.result_dir, @@ -880,7 +879,10 @@ def multithread_execution(fun, arguments): records, threads = [], [] def threadable_fun(*args): - records.append(fun(*args)) + r = fun(*args) + if type(r) is not tuple: + r = (r,) + records.append(r) for args in arguments: # To get a different sequence between threads @@ -952,19 +954,19 @@ for n_epoch in range(current_epoch, args.nb_epochs): nb_gpus = len(gpus) nb_c_quizzes_to_generate = (args.nb_c_quizzes + nb_gpus - 1) // nb_gpus - c_quizzes = multithread_execution( + (c_quizzes,) = multithread_execution( generate_c_quizzes, [(models, nb_c_quizzes_to_generate, gpu) for gpu in gpus], ) - save_c_quizzes_with_scores( + save_quiz_image( models, c_quizzes[:256], f"culture_c_quiz_{n_epoch:04d}.png", solvable_only=False, ) - save_c_quizzes_with_scores( + save_quiz_image( models, c_quizzes[:256], f"culture_c_quiz_{n_epoch:04d}_solvable.png", @@ -974,7 +976,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): u = c_quizzes.reshape(c_quizzes.size(0), 4, -1)[:, :, 1:] i = (u[:, 2] != u[:, 3]).long().sum(dim=1).sort(descending=True).indices - save_c_quizzes_with_scores( + save_quiz_image( models, c_quizzes[i][:256], f"culture_c_quiz_{n_epoch:04d}_solvable_high_delta.png", -- 2.39.5