From 9f4dead1dda719e24132ffd8689d558843f0b1c5 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 19 Sep 2024 12:38:26 +0200 Subject: [PATCH] Update. --- attae.py | 4 ++-- grids.py | 18 ++++++++++------ main.py | 57 ++++++++++++++++++++++++++++++++----------------- quiz_machine.py | 6 ++++++ 4 files changed, 57 insertions(+), 28 deletions(-) diff --git a/attae.py b/attae.py index a9bdeba..1e5e122 100755 --- a/attae.py +++ b/attae.py @@ -180,8 +180,8 @@ class FunctionalAttentionAE(AttentionAE): a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3)) n = self.nb_work_tokens s = (q.size(2) - n) // 2 - a[:, :, n + 0 * s : n + 1 * s, n + 0 * s : n + 1 * s] = float("-inf") - a[:, :, n + 1 * s : n + 2 * s, n + 1 * s : n + 2 * s] = float("-inf") + a[:, :, n + 1 * s : n + 2 * s, n + 0 * s : n + 1 * s] = float("-inf") + a[:, :, n + 0 * s : n + 1 * s, n + 1 * s : n + 2 * s] = float("-inf") a = a.softmax(dim=3) y = torch.einsum("nhts,nhsd->nhtd", a, v) return y diff --git a/grids.py b/grids.py index 23a3d12..4254b32 100755 --- a/grids.py +++ b/grids.py @@ -134,24 +134,28 @@ def grow_islands(nb, height, width, nb_seeds, nb_iterations): class Grids(problem.Problem): - # grid_gray=64 + grid_gray = 64 + thickness = 1 + background_gray = 255 + + # grid_gray=240 # thickness=1 - # background_gray=255 + # background_gray=240 - grid_gray = 255 - thickness = 0 - background_gray = grid_gray + # grid_gray = 255 + # thickness = 0 + # background_gray = 240 named_colors = [ ("white", [background_gray, background_gray, background_gray]), # ("white", [224, 224, 224]), ("red", [255, 0, 0]), - ("green", [0, 192, 0]), + ("green", [0, 160, 0]), ("blue", [0, 0, 255]), ("yellow", [255, 224, 0]), ("cyan", [0, 255, 255]), ("violet", [224, 128, 255]), - ("lightgreen", [192, 255, 192]), + ("lightgreen", [160, 255, 160]), ("brown", [165, 42, 42]), ("lightblue", [192, 192, 255]), ("gray", [128, 128, 128]), diff --git a/main.py b/main.py index d903693..750d1b1 100755 --- a/main.py +++ b/main.py @@ -54,7 +54,7 @@ parser.add_argument("--nb_train_samples", type=int, default=50000) parser.add_argument("--nb_test_samples", type=int, default=1000) -parser.add_argument("--nb_c_quizzes", type=int, default=10000) +parser.add_argument("--nb_c_quizzes", type=int, default=5000) parser.add_argument("--c_quiz_multiplier", type=int, default=1) @@ -64,7 +64,7 @@ parser.add_argument("--nb_have_to_be_correct", type=int, default=3) parser.add_argument("--nb_have_to_be_wrong", type=int, default=1) -parser.add_argument("--nb_mistakes_to_be_wrong", type=int, default=10) +parser.add_argument("--nb_mistakes_to_be_wrong", type=int, default=5) # ---------------------------------- @@ -198,7 +198,8 @@ if args.seed >= 0: def log_string(s): - """print the given string prefixed with a time stamps, and log it into log_file is not None""" + """print the given string prefixed with a time stamps, and log it + into log_file is not None""" t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime()) @@ -301,10 +302,9 @@ log_string(f"vocabulary_size {vocabulary_size}") ###################################################################### -# If we need to move an optimizer to a different device - def optimizer_to(optim, device): + """Move the optimizer optim to the device""" for param in optim.state.values(): # Not sure there are any global tensors in the state dict if isinstance(param, torch.Tensor): @@ -322,11 +322,12 @@ def optimizer_to(optim, device): ###################################################################### -# Make args.nb_hints holes in the mask and copy the corresponding cell -# values from the target to the input - - def add_hints_imt(imt_set): + """Set every component of the mask to zero with probability + args.proba_hint, and for each component set to zero, copy the + corresponding value from the target into the input + + """ input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2] # h = torch.rand(masks.size(), device=masks.device) - masks # t = h.sort(dim=1).values[:, args.nb_hints, None] @@ -339,11 +340,9 @@ def add_hints_imt(imt_set): return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1) -# Make pixels from the available input (mask=0) noise with probability -# args.proba_prompt_noise - - def add_noise_imt(imt_set): + """Replace every component of the input by a random value with + probability args.proba_prompt_noise.""" input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2] noise = quiz_machine.pure_noise(input.size(0), input.device) change = (1 - masks) * ( @@ -633,8 +632,8 @@ import attae models = [] for i in range(args.nb_models): - model = attae.FunctionalAttentionAE( - # model = attae.AttentionAE( + # model = attae.FunctionalAttentionAE( + model = attae.AttentionAE( vocabulary_size=vocabulary_size * 2, dim_model=args.dim_model, dim_keys=args.dim_keys, @@ -655,7 +654,7 @@ for i in range(args.nb_models): ###################################################################### -def evaluate_quizzes(quizzes, models, fraction_with_hints, local_device): +def evaluate_quizzes(quizzes, models, with_perturbations, local_device): nb_correct, nb_wrong = 0, 0 for model in models: @@ -663,7 +662,7 @@ def evaluate_quizzes(quizzes, models, fraction_with_hints, local_device): result = predict_full( model=model, input=quizzes, - with_perturbations=True, + with_perturbations=with_perturbations, local_device=local_device, ) nb_mistakes = (result != quizzes).long().sum(dim=1) @@ -680,6 +679,26 @@ def evaluate_quizzes(quizzes, models, fraction_with_hints, local_device): ###################################################################### +def remove_old_problematic(c_quizzes, models, nb_to_remove, local_device): + nb_removed = 0 + for input in c_quizzes.split(args.eval_batch_size): + _, nb_correct, nb_wrong = evaluate_quizzes( + quizzes=input, + models=models, + with_perturbations=False, + local_device=local_device, + ) + + to_remove = nb_wrong > 0 + nb_removed += to_remove.long().sum() + + if nb_removed >= nb_to_remove: + break + + +###################################################################### + + def identity_quizzes(quizzes): quizzes = quizzes.reshape(quizzes.size(0), 4, -1) return (quizzes[:, 0] == quizzes[:, 1]).min(dim=1).values & ( @@ -714,7 +733,7 @@ def generate_c_quizzes(models, nb_to_generate, local_device=main_device): to_keep, nb_correct, nb_wrong = evaluate_quizzes( quizzes=c_quizzes, models=models, - fraction_with_hints=1.0, + with_perturbations=True, local_device=local_device, ) @@ -760,7 +779,7 @@ def save_quiz_image(models, c_quizzes, filename, local_device=main_device): to_keep, nb_correct, nb_wrong = evaluate_quizzes( quizzes=c_quizzes, models=models, - fraction_with_hints=0, + with_perturbations=False, local_device=local_device, ) diff --git a/quiz_machine.py b/quiz_machine.py index e2f6d3b..72f1d16 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -206,6 +206,8 @@ class QuizMachine: quizzes = quizzes.view(quizzes.size(0), 4, -1)[:, :, 1:].reshape( quizzes.size(0), -1 ) + nb_w_quizzes = quizzes.size(0) + nb_c_quizzes = 0 else: if c_quiz_multiplier > 1: n = min(c_quiz_multiplier, (nb_samples // 2) // c_quizzes.size(0)) @@ -229,10 +231,14 @@ class QuizMachine: w_quizzes.size(0), -1 ) quizzes = torch.cat([w_quizzes, c_quizzes], dim=0) + nb_w_quizzes = w_quizzes.size(0) + nb_c_quizzes = c_quizzes.size(0) i = torch.randperm(quizzes.size(0), device=quizzes.device) quizzes = quizzes[i].contiguous() + logger(f"quiz_set nb_w_quizzes {nb_w_quizzes} nb_c_quizzes {nb_c_quizzes}") + return quizzes ###################################################################### -- 2.39.5