From dc2920921f3a98a9cb2bea75f83c8286857198d9 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 18 Sep 2024 15:26:27 +0200 Subject: [PATCH] Update. --- main.py | 65 +++++++++++++++++++++++++------------------------ quiz_machine.py | 10 +++++--- 2 files changed, 39 insertions(+), 36 deletions(-) diff --git a/main.py b/main.py index e38cbc0..edc366a 100755 --- a/main.py +++ b/main.py @@ -373,12 +373,20 @@ def masked_cross_entropy(output, targets, masks): ###################################################################### +def add_hints_(imt_set): + input, masks, targets = imt_set + h = torch.rand(masks.size(), device=masks.device) - masks + t = h.sort(dim=1).values[:, args.nb_hints, None] + mask_hints = (h < t).long() + masks[...] = (1 - mask_hints) * masks + input[...] = (1 - mask_hints) * input + mask_hints * targets + + def add_hints(masks, fraction_with_hints): if fraction_with_hints > 0: - h = torch.rand(masks.size(), device=masks.device) * masks - mask_hints = h.sort(dim=1, descending=True).values < args.nb_hints - v = torch.rand(masks.size(0), device=masks.device)[:, None] - mask_hints = mask_hints * (v < fraction_with_hints).long() + h = torch.rand(masks.size(), device=masks.device) - masks + t = h.sort(dim=1).values[:, args.nb_hints, None] + mask_hints = (h < t).long() return (1 - mask_hints) * masks else: return masks @@ -387,19 +395,18 @@ def add_hints(masks, fraction_with_hints): # IMT for input / masks / target -def batch_prediction_imt(input, fraction_with_hints=0.0): +def batch_for_prediction_imt(input): nb = input.size(0) masks = input.new_zeros(input.size()) u = F.one_hot(torch.randint(4, (nb,), device=masks.device), num_classes=4) masks.view(nb, 4, -1)[...] = u[:, :, None] - masks = add_hints(masks, fraction_with_hints) targets = input input = (1 - masks) * targets return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1) -def predict(model, imt_set, local_device=main_device, desc="predict"): +def ae_predict(model, imt_set, local_device=main_device, desc="predict"): model.eval().to(local_device) record = [] @@ -428,20 +435,17 @@ def predict(model, imt_set, local_device=main_device, desc="predict"): return torch.cat(record) -def predict_full(model, input, fraction_with_hints, local_device=main_device): +def predict_full(model, input, local_device=main_device): input = input[:, None, :].expand(-1, 4, -1).reshape(-1, input.size(1)) nb = input.size(0) masks = input.new_zeros(input.size()) u = F.one_hot(torch.arange(nb, device=masks.device) % 4, num_classes=4) masks.view(nb, 4, -1)[...] = u[:, :, None] - masks_with_hints = add_hints(masks, fraction_with_hints) targets = input - input = (1 - masks_with_hints) * targets - imt_set = torch.cat( - [input[:, None], masks_with_hints[:, None], targets[:, None]], dim=1 - ) + input = (1 - masks) * targets + imt_set = torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1) - result = predict(model, imt_set, local_device=local_device, desc=None) + result = ae_predict(model, imt_set, local_device=local_device, desc=None) result = (result * masks).reshape(-1, 4, result.size(1)).sum(dim=1) return result @@ -450,7 +454,7 @@ def predict_full(model, input, fraction_with_hints, local_device=main_device): ###################################################################### -def batch_generation_imt(input): +def batch_for_generation_imt(input): nb = input.size(0) probs_iterations = 0.1 ** torch.linspace( 0, 1, args.diffusion_nb_iterations, device=input.device @@ -480,7 +484,7 @@ def prioritized_rand(low): return y -def generate(model, nb, local_device=main_device, desc="generate"): +def ae_generate(model, nb, local_device=main_device, desc="generate"): model.eval().to(local_device) all_input = quiz_machine.pure_noise(nb, local_device) @@ -533,8 +537,8 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True): imt_set = torch.cat( [ - batch_prediction_imt(q1, fraction_with_hints=0.5), - batch_generation_imt(q2), + batch_for_prediction_imt(q1), + batch_for_generation_imt(q2), ] ) @@ -597,13 +601,13 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device): args.result_dir, f"test_{n_epoch}_{model.id}_predict_full.png", quizzes=result ) - # Save some images of the prediction results (one grid at random) + # Save some images of the prediction results quizzes = quiz_machine.quiz_set( args.nb_test_samples, c_quizzes, args.c_quiz_multiplier ) - imt_set = batch_prediction_imt(quizzes.to(local_device)) - result = predict(model, imt_set, local_device=local_device).to("cpu") + imt_set = batch_for_prediction_imt(quizzes.to(local_device)) + result = ae_predict(model, imt_set, local_device=local_device).to("cpu") masks = imt_set[:, 1].to("cpu") correct = (quizzes == result).min(dim=1).values.long() @@ -631,7 +635,7 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device): # Save some images of the ex nihilo generation of the four grids - result = generate(model, 150, local_device=local_device).to("cpu") + result = ae_generate(model, 150, local_device=local_device).to("cpu") quiz_machine.problem.save_quizzes_as_image( args.result_dir, f"culture_generation_{n_epoch}_{model.id}.png", @@ -695,21 +699,21 @@ def evaluate_quizzes(quizzes, models, fraction_with_hints, local_device): ###################################################################### -def generate_c_quizzes(models, nb, local_device=main_device): +def generate_c_quizzes(models, nb_to_generate, local_device=main_device): record = [] nb_validated = 0 start_time = time.perf_counter() last_log = -1 - while nb_validated < nb: + while nb_validated < nb_to_generate: # Generate new quizzes model = models[torch.randint(len(models), (1,)).item()] model = copy.deepcopy(model).to(local_device).eval() generator_id = model.id - c_quizzes = generate( + c_quizzes = ae_generate( model=model, nb=args.physical_batch_size, local_device=local_device, @@ -736,8 +740,8 @@ def generate_c_quizzes(models, nb, local_device=main_device): if last_log < 0 or duration > last_log + 10: last_log = duration if nb_validated > 0: - if nb_validated < nb: - d = (nb - nb_validated) * duration / nb_validated + if nb_validated < nb_to_generate: + d = (nb_to_generate - nb_validated) * duration / nb_validated e = ( datetime.datetime.now() + datetime.timedelta(seconds=d) ).strftime("%a %H:%M") @@ -754,7 +758,7 @@ def generate_c_quizzes(models, nb, local_device=main_device): duration = time.perf_counter() - start_time - log_string(f"generate_c_quizz_speed {int(3600 * nb / duration)}/h") + log_string(f"generate_c_quizz_speed {int(3600 * nb_validated / duration)}/h") return torch.cat(record).to("cpu") @@ -848,7 +852,7 @@ if args.quizzes is not None: mask_generate = quiz_machine.make_quiz_mask( quizzes=quizzes, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad ) - result = generate( + result = ae_generate( model, (1 - mask_generate) * quizzes, mask_generate, @@ -874,8 +878,6 @@ if args.quizzes is not None: ###################################################################### -last_n_epoch_c_quizzes = 0 - c_quizzes = None time_c_quizzes = 0 @@ -963,7 +965,6 @@ for n_epoch in range(current_epoch, args.nb_epochs): if c_quizzes is None: save_models(models, "naive") - last_n_epoch_c_quizzes = n_epoch nb_gpus = len(gpus) nb_c_quizzes_to_generate = (args.nb_c_quizzes + nb_gpus - 1) // nb_gpus diff --git a/quiz_machine.py b/quiz_machine.py index 594b5ca..e2f6d3b 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -203,6 +203,9 @@ class QuizMachine: def quiz_set(self, nb_samples, c_quizzes, c_quiz_multiplier=1): if c_quizzes is None: quizzes = self.problem.generate_w_quizzes(nb_samples) + quizzes = quizzes.view(quizzes.size(0), 4, -1)[:, :, 1:].reshape( + quizzes.size(0), -1 + ) else: if c_quiz_multiplier > 1: n = min(c_quiz_multiplier, (nb_samples // 2) // c_quizzes.size(0)) @@ -222,15 +225,14 @@ class QuizMachine: c_quizzes = c_quizzes[i] w_quizzes = self.problem.generate_w_quizzes(nb_samples - c_quizzes.size(0)) + w_quizzes = w_quizzes.view(w_quizzes.size(0), 4, -1)[:, :, 1:].reshape( + w_quizzes.size(0), -1 + ) quizzes = torch.cat([w_quizzes, c_quizzes], dim=0) i = torch.randperm(quizzes.size(0), device=quizzes.device) quizzes = quizzes[i].contiguous() - quizzes = quizzes.view(quizzes.size(0), 4, -1)[:, :, 1:].reshape( - quizzes.size(0), -1 - ) - return quizzes ###################################################################### -- 2.39.5