From: François Fleuret Date: Sun, 4 Aug 2024 05:02:16 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=54c7af40fc3afcc4915147375952a46e80194f40;p=culture.git Update. --- diff --git a/main.py b/main.py index 9a8bd43..8f3568f 100755 --- a/main.py +++ b/main.py @@ -477,110 +477,110 @@ c_quizzes_procedure = [ ###################################################################### -def save_additional_results(models, science_w_quizzes): +def save_additional_results(model, models, science_w_quizzes): # Save generated quizzes with the successive steps - for model in models: - recorder = [] + recorder = [] - c_quizzes = quiz_machine.generate_c_quizzes( - 64, - model_for_generation=model, - procedure=c_quizzes_procedure, - recorder=recorder, - ) + c_quizzes = quiz_machine.generate_c_quizzes( + 64, + model_for_generation=model, + procedure=c_quizzes_procedure, + recorder=recorder, + ) - ## + ## - probas = 0 + probas = 0 - for a in range(args.nb_averaging_rounds): - # This is nb_quizzes x nb_models + for a in range(args.nb_averaging_rounds): + # This is nb_quizzes x nb_models - seq_logproba = quiz_machine.models_logprobas( - models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) - ) + quiz_machine.models_logprobas( - models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) - ) + seq_logproba = quiz_machine.models_logprobas( + models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) + ) + quiz_machine.models_logprobas( + models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) + ) - probas += seq_logproba.exp() + probas += seq_logproba.exp() - probas /= args.nb_averaging_rounds + probas /= args.nb_averaging_rounds - comments = [] + comments = [] - for l in seq_logproba: - comments.append("proba " + " ".join([f"{x.exp().item():.02f}" for x in l])) + for l in seq_logproba: + comments.append("proba " + " ".join([f"{x.exp().item():.02f}" for x in l])) - ## + ## - c_quizzes = torch.cat([c[:, None, :] for c, _, in recorder], dim=1) - predicted_parts = torch.cat([t[:, None, :] for _, t in recorder], dim=1) - nb_steps = c_quizzes.size(1) - c_quizzes = c_quizzes.reshape(-1, c_quizzes.size(-1)) - predicted_parts = predicted_parts.reshape(-1, predicted_parts.size(-1)) + c_quizzes = torch.cat([c[:, None, :] for c, _, in recorder], dim=1) + predicted_parts = torch.cat([t[:, None, :] for _, t in recorder], dim=1) + nb_steps = c_quizzes.size(1) + c_quizzes = c_quizzes.reshape(-1, c_quizzes.size(-1)) + predicted_parts = predicted_parts.reshape(-1, predicted_parts.size(-1)) - # We have comments only for the final quiz, not the successive - # steps, so we have to add nb_steps-1 empty comments + # We have comments only for the final quiz, not the successive + # steps, so we have to add nb_steps-1 empty comments - steps_comments = [] - for c in comments: - steps_comments += [""] * (nb_steps - 1) + [c] + steps_comments = [] + for c in comments: + steps_comments += [""] * (nb_steps - 1) + [c] - filename = f"non_validated_{n_epoch:04d}_{model.id:02d}.png" - quiz_machine.problem.save_quizzes_as_image( - args.result_dir, - filename, - quizzes=c_quizzes, - predicted_parts=predicted_parts, - comments=steps_comments, - nrow=nb_steps * 2, # two quiz per row - ) - log_string(f"wrote {filename}") + filename = f"non_validated_{n_epoch:04d}_{model.id:02d}.png" + + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, + filename, + quizzes=c_quizzes, + predicted_parts=predicted_parts, + comments=steps_comments, + nrow=nb_steps * 2, # two quiz per row + ) + + log_string(f"wrote {filename}") ###################################################################### if science_w_quizzes is not None: - for model in models: - struct = ("A", "f_A", "B", "f_B") - mask = (0, 0, 0, 1) - result, correct = quiz_machine.predict( - model=model, - quizzes=science_w_quizzes.to(main_device), - struct=struct, - mask=mask, - ) + struct = ("A", "f_A", "B", "f_B") + mask = (0, 0, 0, 1) + result, correct = quiz_machine.predict( + model=model, + quizzes=science_w_quizzes.to(main_device), + struct=struct, + mask=mask, + ) - predicted_parts = torch.tensor(mask, device=correct.device)[None, :].expand( - correct.size(0), -1 - ) - correct = (2 * correct - 1) * (predicted_parts.sum(dim=-1) == 1).long() + predicted_parts = torch.tensor(mask, device=correct.device)[None, :].expand( + correct.size(0), -1 + ) + correct = (2 * correct - 1) * (predicted_parts.sum(dim=-1) == 1).long() - nb_correct = (correct == 1).long().sum() - nb_total = (correct != 0).long().sum() + nb_correct = (correct == 1).long().sum() + nb_total = (correct != 0).long().sum() - log_string( - f"science_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total}" - ) + log_string( + f"science_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total}" + ) - i = correct == 1 - j = correct != 1 + i = correct == 1 + j = correct != 1 - result = torch.cat([result[i], result[j]], dim=0) - correct = torch.cat([correct[i], correct[j]], dim=0) - correct_parts = predicted_parts * correct[:, None] + result = torch.cat([result[i], result[j]], dim=0) + correct = torch.cat([correct[i], correct[j]], dim=0) + correct_parts = predicted_parts * correct[:, None] - result = result[:128] - predicted_parts = predicted_parts[:128] - correct_parts = correct_parts[:128] + result = result[:128] + predicted_parts = predicted_parts[:128] + correct_parts = correct_parts[:128] - quiz_machine.problem.save_quizzes_as_image( - args.result_dir, - f"culture_science_{n_epoch:04d}_{model.id:02d}.png", - quizzes=result, - predicted_parts=predicted_parts, - correct_parts=correct_parts, - ) + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, + f"culture_science_{n_epoch:04d}_{model.id:02d}.png", + quizzes=result, + predicted_parts=predicted_parts, + correct_parts=correct_parts, + ) ###################################################################### @@ -1310,7 +1310,8 @@ for n_epoch in range(current_epoch, args.nb_epochs): ) log_string(f"wrote {filename}") - save_additional_results(weakest_models, science_w_quizzes) + for model in weakest_models: + save_additional_results(model, models, science_w_quizzes) ######################################################################