From 9671e3f5eb1d7f5124ee3f998a1ae1450fcec5c5 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Mon, 2 Sep 2024 17:26:39 +0200
Subject: [PATCH] Update.

Add a save_badness_statistics helper that saves as an image the
c-quizzes whose solutions the current models find least probable, and
keep c_quizzes in the checkpointed state so that it is restored on
resume.

---
 main.py | 46 ++++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 42 insertions(+), 4 deletions(-)

diff --git a/main.py b/main.py
index b48d2a8..b6aa328 100755
--- a/main.py
+++ b/main.py
@@ -1070,7 +1070,7 @@ def ae_generate(model, input, mask_generate, nb_iterations_max=50):
 ######################################################################
 
 
-def model_ae_proba_solutions(model, input):
+def model_ae_proba_solutions(model, input, log_proba=False):
     record = []
 
     for q in input.split(args.batch_size):
@@ -1089,7 +1089,10 @@ def model_ae_proba_solutions(model, input):
 
     loss = torch.cat(record, dim=0)
 
-    return (-loss).exp()
+    if log_proba:
+        return -loss
+    else:
+        return (-loss).exp()
 
 
 nb_diffusion_iterations = 25
@@ -1351,6 +1354,36 @@ def c_quiz_criterion_some(probas):
     )
 
 
+def save_badness_statistics(
+    n_epoch, models, c_quizzes, suffix=None, local_device=main_device
+):
+    # Sort the c-quizzes by the total log-proba the models assign to
+    # their solutions and save the 128 lowest-scoring ones as an image
+    for model in models:
+        model.eval().to(local_device)
+    c_quizzes = c_quizzes.to(local_device)
+    with torch.autograd.no_grad():
+        log_probas = sum(
+            [model_ae_proba_solutions(model, c_quizzes, log_proba=True) for model in models]
+        )
+        i = log_probas.sort().indices
+
+    suffix = "" if suffix is None else "_" + suffix
+
+    filename = f"culture_badness_{n_epoch:04d}{suffix}.png"
+
+    quiz_machine.problem.save_quizzes_as_image(
+        args.result_dir,
+        filename,
+        quizzes=c_quizzes[i[:128]],
+        # predicted_parts=predicted_parts,
+        # correct_parts=correct_parts,
+        # comments=comments,
+        delta=True,
+        nrow=8,
+    )
+
+
 def generate_ae_c_quizzes(models, local_device=main_device):
     criteria = [
         # c_quiz_criterion_only_one,
@@ -1493,6 +1526,7 @@ if args.resume:
         state = torch.load(os.path.join(args.result_dir, filename))
         log_string(f"successfully loaded {filename}")
         current_epoch = state["current_epoch"]
+        c_quizzes = state["c_quizzes"]
         # total_time_generating_c_quizzes = state["total_time_generating_c_quizzes"]
         # total_time_training_models = state["total_time_training_models"]
         # common_c_quiz_bags = state["common_c_quiz_bags"]
@@ -1520,10 +1554,12 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
     state = {
         "current_epoch": n_epoch,
+        "c_quizzes": c_quizzes,
         # "total_time_generating_c_quizzes": total_time_generating_c_quizzes,
         # "total_time_training_models": total_time_training_models,
         # "common_c_quiz_bags": common_c_quiz_bags,
     }
+
     filename = "state.pth"
     torch.save(state, os.path.join(args.result_dir, filename))
     log_string(f"wrote {filename}")
@@ -1541,10 +1577,12 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     log_string(f"{time_train=} {time_c_quizzes=}")
 
     if (
-        n_epoch >= 200
-        and min([m.test_accuracy for m in models]) > args.accuracy_to_make_c_quizzes
+        min([m.test_accuracy for m in models]) > args.accuracy_to_make_c_quizzes
         and time_train >= time_c_quizzes
     ):
+        if c_quizzes is not None:
+            save_badness_statistics(n_epoch, models, c_quizzes)
+
         last_n_epoch_c_quizzes = n_epoch
         start_time = time.perf_counter()
         c_quizzes = generate_ae_c_quizzes(models, local_device=main_device)
-- 
2.39.5