From 61e065fa623a7717855c1e6b8b530533661d9a54 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Sun, 8 Sep 2024 22:51:44 +0200
Subject: [PATCH] Update.

---
 attae.py |  23 ++++++++-
 main.py  | 166 +++++++++++++++++++++++++++++++++++++++++++------------
 2 files changed, 151 insertions(+), 38 deletions(-)

diff --git a/attae.py b/attae.py
index e9e4bff..069772b 100755
--- a/attae.py
+++ b/attae.py
@@ -45,6 +45,19 @@ class WithResidual(nn.Module):
 ######################################################################
 
 
+def vanilla_attention(q, k, v):
+    a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
+    a = a.softmax(dim=3)
+    y = torch.einsum("nhts,nhsd->nhtd", a, v)
+
+    # y = flex_attention(q, k, v, score_mod=noop)
+
+    return y
+
+
+vanilla_attention = torch.compile(vanilla_attention)
+
+
 class MHAttention(nn.Module):
     def __init__(
         self,
@@ -72,10 +85,14 @@ class MHAttention(nn.Module):
             x_kv = x_q
 
         q = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q)
-        k = torch.einsum("ntc,hdc->nhtd", x_kv, self.w_k)
-        v = torch.einsum("ntc,hdc->nhtd", x_kv, self.w_v)
+        k = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_k)
+        v = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_v)
+
+        def noop(score, b, h, q_idx, kv_idx):
+            return score
 
-        y = flex_attention(q, k, v)
+        y = vanilla_attention(q, k, v)
+        # y = flex_attention(q, k, v, score_mod=noop)
 
         y = torch.einsum("nhtd,hdc->ntc", y, self.w_o)
 
diff --git a/main.py b/main.py
index 9285337..301c4f8 100755
--- a/main.py
+++ b/main.py
@@ -16,8 +16,6 @@ from torch.nn import functional as F
 
 import ffutils
 
-import attae
-
 import mygpt
 import sky, grids, quiz_machine
 
@@ -775,7 +773,7 @@ def ae_generate(model, x_0, mask_generate, nb_iterations_max=50):
 ######################################################################
 
 
-def model_ae_proba_solutions(model, input, log_proba=False):
+def model_ae_proba_solutions(model, input, log_probas=False, reduce=True):
     record = []
 
     for x_0 in input.split(args.batch_size):
@@ -791,12 +789,16 @@
             loss_per_token = F.cross_entropy(
                 logits.transpose(1, 2), x_0, reduction="none"
             )
-            loss += (loss_per_token * mask_generate).sum(dim=1)
+            if reduce:
+                loss += (loss_per_token * mask_generate).sum(dim=1)
+            else:
+                loss += loss_per_token * mask_generate
+
         record.append(loss)
 
     loss = torch.cat(record, dim=0)
 
-    if log_proba:
+    if log_probas:
         return -loss
     else:
         return (-loss).exp()
@@ -811,6 +813,7 @@ def model_ae_argmax_nb_mistakes(model, input):
         mask_generate = quiz_machine.make_quiz_mask(
             quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
         )
+
         logits = logits_hat_x_0_from_random_iteration(
             model, x_0, mask_generate, prompt_noise=args.prompt_noise
         )
@@ -996,11 +999,13 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_device):
 ######################################################################
 
 
+# import attae
+
 models = []
 
 for i in range(args.nb_models):
-    # model = MyAttentionAE(
-    model = attae.AttentionAE(
+    model = MyAttentionAE(
+        # model = attae.AttentionAE(
         vocabulary_size=vocabulary_size,
         dim_model=args.dim_model,
         dim_keys=args.dim_keys,
@@ -1055,11 +1060,11 @@ def save_badness_statistics(
 ######################################################################
 
-def quiz_validation(models, c_quizzes, local_device):
+def quiz_validation_1(models, c_quizzes, local_device):
     nb_have_to_be_correct = args.nb_models // 2
-    nb_have_to_be_wrong = args.nb_models // 5
+    nb_have_to_be_wrong = 1
 
-    nb_runs = 3
+    nb_runs = 1
 
     nb_mistakes_to_be_wrong = 5
 
     record_wrong = []
@@ -1088,6 +1093,78 @@
 
     return to_keep, wrong
 
 
+def quiz_validation_2(models, c_quizzes, local_device):
+    nb_have_to_be_correct = 3
+    nb_have_to_be_wrong = 1
+    nb_runs = 3
+
+    record_wrong = []
+    nb_correct, nb_wrong = 0, 0
+
+    for i, model in enumerate(models):
+        assert i == model.id  # a bit of paranoia
+        model = copy.deepcopy(model).to(local_device).eval()
+        log_probas_max, log_probas_min = None, None
+        for _ in range(nb_runs):
+            log_probas = model_ae_proba_solutions(
+                model, c_quizzes, log_probas=True, reduce=False
+            )
+            log_probas_max = (
+                log_probas
+                if log_probas_max is None
+                else log_probas.maximum(log_probas_max)
+            )
+            log_probas_min = (
+                log_probas
+                if log_probas_min is None
+                else log_probas.minimum(log_probas_min)
+            )
+        probas = log_probas.sum(dim=1).exp()
+        correct = (log_probas_min.exp() <= 0.75).long().sum(dim=1) == 0
+        wrong = (log_probas_min.exp() <= 0.1).long().sum(dim=1) >= 3
+        record_wrong.append(wrong[:, None])
+        nb_correct += correct.long()
+        nb_wrong += wrong.long()
+
+    to_keep = (nb_correct >= nb_have_to_be_correct) & (nb_wrong >= nb_have_to_be_wrong)
+
+    wrong = torch.cat(record_wrong, dim=1)
+
+    return to_keep, wrong
+
+
+def quiz_validation(models, c_quizzes, local_device):
+    nb_have_to_be_correct = 3
+    nb_have_to_be_wrong = 1
+    nb_runs = 3
+
+    record_wrong = []
+    nb_correct, nb_wrong = 0, 0
+
+    for i, model in enumerate(models):
+        assert i == model.id  # a bit of paranoia
+        model = copy.deepcopy(model).to(local_device).eval()
+        log_probas = 0
+        for _ in range(nb_runs):
+            log_probas += model_ae_proba_solutions(
+                model, c_quizzes, log_probas=True, reduce=False
+            )
+        probas = log_probas.exp()
+        correct = (probas <= 0.75).long().sum(dim=1) == 0
+        wrong = ((probas <= 0.125).long().sum(dim=1) >= 5) & (
+            log_probas.sum(dim=1).div(nb_runs).exp() <= 0.5
+        )
+        record_wrong.append(wrong[:, None])
+        nb_correct += correct.long()
+        nb_wrong += wrong.long()
+
+    to_keep = (nb_correct >= nb_have_to_be_correct) & (nb_wrong >= nb_have_to_be_wrong)
+
+    wrong = torch.cat(record_wrong, dim=1)
+
+    return to_keep, wrong
+
+
 def generate_ae_c_quizzes(models, nb, local_device=main_device):
     # To be thread-safe we must make copies
@@ -1305,20 +1382,28 @@
 
     start_time = time.perf_counter()
 
-    for gpu in gpus:
-        t = threading.Thread(
-            target=thread_generate_ae_c_quizzes,
-            daemon=True,
-            args=(models, nb_c_quizzes_to_generate, records, gpu),
-        )
+    if len(gpus) > 1:
+        for gpu in gpus:
+            t = threading.Thread(
+                target=thread_generate_ae_c_quizzes,
+                daemon=True,
+                args=(models, nb_c_quizzes_to_generate, records, gpu),
+            )
 
-        # To get a different sequence between threads
-        log_string(f"dummy {torch.rand(1)}")
-        threads.append(t)
-        t.start()
+            # To get a different sequence between threads
+            log_string(f"dummy {torch.rand(1)}")
+            threads.append(t)
+            t.start()
 
-    for t in threads:
-        t.join()
+        for t in threads:
+            t.join()
+
+    else:
+        records.append(
+            generate_ae_c_quizzes(
+                models, nb_c_quizzes_to_generate, gpus[0]
+            )
+        )
 
     time_c_quizzes = int(time.perf_counter() - start_time)
 
@@ -1350,25 +1435,36 @@
 
     start_time = time.perf_counter()
 
-    for gpu, model in zip(gpus, weakest_models):
+    if len(gpus) > 1:
+        for gpu, model in zip(gpus, weakest_models):
+            log_string(f"training model {model.id} (accuracy {model.test_accuracy})")
+            if c_quizzes is None:
+                c_quizzes_for_this_model = None
+            else:
+                c_quizzes_for_this_model = c_quizzes[agreements[:, model.id]]
+
+            t = threading.Thread(
+                target=one_ae_epoch,
+                daemon=True,
+                args=(model, quiz_machine, n_epoch, c_quizzes_for_this_model, gpu),
+            )
+
+            threads.append(t)
+
+            t.start()
+
+        for t in threads:
+            t.join()
+
+    else:
+        model = weakest_models[0]
         log_string(f"training model {model.id} (accuracy {model.test_accuracy})")
         if c_quizzes is None:
             c_quizzes_for_this_model = None
         else:
             c_quizzes_for_this_model = c_quizzes[agreements[:, model.id]]
 
-        t = threading.Thread(
-            target=one_ae_epoch,
-            daemon=True,
-            args=(model, quiz_machine, n_epoch, c_quizzes_for_this_model, gpu),
-        )
-
-        threads.append(t)
-
-        t.start()
-
-    for t in threads:
-        t.join()
+        one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes_for_this_model, gpus[0])
 
     time_train += int(time.perf_counter() - start_time)
-- 
2.39.5
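
Note: the vanilla_attention path added in attae.py above can be sanity-checked
in isolation. Below is a minimal, self-contained sketch of it; the tensor sizes
are illustrative only and are not taken from the patch:

    import math
    import torch

    def vanilla_attention(q, k, v):
        # Plain softmax attention over per-head queries/keys/values,
        # equivalent to flex_attention with an identity score_mod.
        a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
        a = a.softmax(dim=3)
        return torch.einsum("nhts,nhsd->nhtd", a, v)

    q = torch.randn(2, 4, 10, 16)  # (N, H, T, D): batch 2, 4 heads, 10 queries
    k = torch.randn(2, 4, 12, 16)  # 12 key/value positions
    v = torch.randn(2, 4, 12, 16)
    y = vanilla_attention(q, k, v)
    assert y.shape == (2, 4, 10, 16)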