From b4fab2b79c9569179c023e3011df4f58ddf64bdb Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 18 Sep 2024 09:11:36 +0200 Subject: [PATCH] Update. --- main.py | 52 +++++++++++++++++++++++++++++++++------------------- 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/main.py b/main.py index 4488a70..5f80fb5 100755 --- a/main.py +++ b/main.py @@ -415,17 +415,22 @@ def batch_prediction_imt(input, fraction_with_hints=0.0): return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1) -def predict(model, imt_set, local_device=main_device): +def predict(model, imt_set, local_device=main_device, desc="predict"): model.eval().to(local_device) record = [] - for imt in tqdm.tqdm( - imt_set.split(args.physical_batch_size), - dynamic_ncols=True, - desc="predict", - total=imt_set.size(0) // args.physical_batch_size, - ): + src = imt_set.split(args.physical_batch_size) + + if desc is not None: + src = tqdm.tqdm( + src, + dynamic_ncols=True, + desc=desc, + total=imt_set.size(0) // args.physical_batch_size, + ) + + for imt in src: # some paranoia imt = imt.clone() imt[:, 0] = imt[:, 0] * (1 - imt[:, 1]) @@ -452,7 +457,7 @@ def predict_full(model, input, fraction_with_hints=0.0, local_device=main_device [input[:, None], masks_with_hints[:, None], targets[:, None]], dim=1 ) - result = predict(model, imt_set, local_device=local_device) + result = predict(model, imt_set, local_device=local_device, desc=None) result = (result * masks).reshape(-1, 4, result.size(1)).sum(dim=1) return result @@ -491,21 +496,26 @@ def prioritized_rand(low): return y -def generate(model, nb, local_device=main_device): +def generate(model, nb, local_device=main_device, desc="generate"): model.eval().to(local_device) all_input = quiz_machine.pure_noise(nb, local_device) all_masks = all_input.new_full(all_input.size(), 1) - for input, masks in tqdm.tqdm( - zip( - all_input.split(args.physical_batch_size), - all_masks.split(args.physical_batch_size), - ), - dynamic_ncols=True, - desc="generate", - total=all_input.size(0) // args.physical_batch_size, - ): + src = zip( + all_input.split(args.physical_batch_size), + all_masks.split(args.physical_batch_size), + ) + + if desc is not None: + src = tqdm.tqdm( + src, + dynamic_ncols=True, + desc="generate", + total=all_input.size(0) // args.physical_batch_size, + ) + + for input, masks in src: changed = True for it in range(args.diffusion_nb_iterations): with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): @@ -877,12 +887,14 @@ def generate_c_quizzes(models, nb, local_device=main_device): generator_id = model.id c_quizzes = generate( - moel=copy_for_inference(model), + model=model, nb=args.physical_batch_size, local_device=local_device, + desc=None, ) nb_correct, nb_wrong = 0, 0 + for i, model in enumerate(models): model = copy.deepcopy(model).to(local_device).eval() result = predict_full(model, c_quizzes, local_device=local_device) @@ -897,6 +909,8 @@ def generate_c_quizzes(models, nb, local_device=main_device): nb_validated += to_keep.long().sum() record.append(c_quizzes[to_keep]) + log_string(f"generate_c_quizzes {nb_validated}") + return torch.cat(record) -- 2.39.5