From 9db1a0b7bcb3ffe931dbf847a58a6bd1111b2144 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 5 Sep 2024 22:44:02 +0200 Subject: [PATCH] Update. --- main.py | 80 ++++++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 68 insertions(+), 12 deletions(-) diff --git a/main.py b/main.py index 8e938db..f609fd8 100755 --- a/main.py +++ b/main.py @@ -881,6 +881,32 @@ def model_ae_argmax_nb_disagreements(model, input): return torch.cat(record, dim=0) +###################################################################### + + +def model_ae_argmax_predictions(model, input): + result = input.clone() + # result[...] = 0 + + for r, q in zip(result.split(args.batch_size), input.split(args.batch_size)): + for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]: + mask_generate = quiz_machine.make_quiz_mask( + quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad + ) + targets, logits = targets_and_prediction( + model, q, mask_generate, prompt_noise=args.prompt_noise + ) + + predicted = logits.argmax(dim=-1) + + r[...] = (1 - mask_generate) * r + mask_generate * predicted + + return result + + +###################################################################### + + def degrade_input_to_generate(input, mask_generate, steps_nb_iterations): noise = torch.randint( quiz_machine.problem.nb_colors, input.size(), device=input.device @@ -942,7 +968,9 @@ def targets_and_prediction(model, input, mask_generate, prompt_noise=0.0): def run_ae_test( model, quiz_machine, n_epoch, c_quizzes=None, local_device=main_device, prefix=None ): - if prefix is not None: + if prefix is None: + prefix = "" + else: prefix = prefix + "_" with torch.autograd.no_grad(): @@ -1216,14 +1244,15 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device): wanted_nb = nb nb_to_save = 256 + nb_c_quizzes_per_model = torch.zeros(len(models), device=local_device) with torch.autograd.no_grad(): - records = [] + record_c_quizzes, record_agreements = [], [] last_log = -1 start_time = time.perf_counter() - while bag_len(records) < wanted_nb: + while nb_c_quizzes_per_model.min() < wanted_nb: model = copy_for_inference(models[torch.randint(len(models), (1,)).item()]) generator_id = model.id @@ -1242,7 +1271,8 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device): # to_keep = c_quiz_criterion_two_good(probas) nb_disagreements = [] - for model in models: + for i, model in enumerate(models): + assert i == model.id # a bit of paranoia model = copy_for_inference(model) nb_disagreements.append( model_ae_argmax_nb_disagreements(model, c_quizzes).long()[ @@ -1252,15 +1282,18 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device): nb_disagreements = torch.cat(nb_disagreements, dim=1) v = nb_disagreements.sort(dim=1).values - to_keep = (v[:, 1] == 0) & (v[:, -1] > 3) + to_keep = (v[:, 2] == 0) & (v[:, -1] >= 4) q = c_quizzes[to_keep] if q.size(0) > 0: - records.append(q) + record_c_quizzes.append(q) + a = (nb_disagreements == 0)[to_keep] + record_agreements.append(a) + nb_c_quizzes_per_model += a.long().sum(dim=0) duration = time.perf_counter() - start_time - nb_generated = bag_len(records) + nb_generated = nb_c_quizzes_per_model.min().item() if last_log < 0 or duration > last_log + 5: last_log = duration @@ -1276,17 +1309,33 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device): e = "???" log_string( - f"nb_generated {bag_len(records)} model {generator_id} (finishes {e} -- {int((nb_generated * 3600)/duration)}/h)" + f"nb_generated {bag_len(record_c_quizzes)} model {generator_id} (finishes {e} -- {int((nb_generated * 3600)/duration)}/h)" ) duration = time.perf_counter() - start_time log_string(f"generate_c_quizz_speed {int(3600 * wanted_nb / duration)}/h") - c_quizzes = torch.cat(records, dim=0).unique(dim=0) + c_quizzes = torch.cat(record_c_quizzes, dim=0) + agreements = torch.cat(record_agreements, dim=0) subset_c_quizzes = c_quizzes[:nb_to_save] + # #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + # for model in models: + # model = copy_for_inference(model) + # prediction = model_ae_argmax_predictions(model, subset_c_quizzes) + # filename = f"prediction_c_quiz_{n_epoch:04d}_{model.id}.png" + # quiz_machine.problem.save_quizzes_as_image( + # args.result_dir, + # filename, + # quizzes=prediction, + # nrow=8, + # ) + # log_string(f"wrote {filename}") + # exit(0) + #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + filename = f"culture_c_quiz_{n_epoch:04d}.png" # c_quizzes, predicted_parts, correct_parts = bag_to_tensors(record) @@ -1314,7 +1363,7 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device): log_string(f"wrote {filename}") - return c_quizzes + return c_quizzes, agreements def thread_generate_ae_c_quizzes(models, nb, record, local_device=main_device): @@ -1443,7 +1492,10 @@ for n_epoch in range(current_epoch, args.nb_epochs): time_c_quizzes = int(time.perf_counter() - start_time) - c_quizzes = torch.cat([q.to(main_device) for q in records], dim=0) + c_quizzes = torch.cat([q.to(main_device) for q, _ in records], dim=0) + agreements = torch.cat([a.to(main_device) for _, a in records], dim=0) + + print(f"DEBUG {c_quizzes.size()=} {agreements.size()=}") # -------------------------------------------------------------------- @@ -1474,11 +1526,15 @@ for n_epoch in range(current_epoch, args.nb_epochs): for gpu, model in zip(gpus, weakest_models): log_string(f"training model {model.id} (accuracy {model.test_accuracy})") + if c_quizzes is None: + c_quizzes_for_this_model = None + else: + c_quizzes_for_this_model = c_quizzes[agreements[:, model.id]] t = threading.Thread( target=one_ae_epoch, daemon=True, - args=(model, quiz_machine, n_epoch, c_quizzes, gpu), + args=(model, quiz_machine, n_epoch, c_quizzes_for_this_model, gpu), ) threads.append(t) -- 2.39.5