From 7b8bd914458b34d793a5bacb961c882862a967a1 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 18 Sep 2024 08:57:27 +0200 Subject: [PATCH] Update. --- main.py | 73 +++++++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 66 insertions(+), 7 deletions(-) diff --git a/main.py b/main.py index 6cbb2c4..4488a70 100755 --- a/main.py +++ b/main.py @@ -383,7 +383,7 @@ data_structures = [ def masked_cross_entropy(output, targets, masks): loss_per_token = F.cross_entropy(output.transpose(1, 2), targets, reduction="none") - return (loss_per_token * masks).sum() / masks.expand_as(loss_per_token).sum() + return (loss_per_token * masks).mean() ###################################################################### @@ -492,6 +492,8 @@ def prioritized_rand(low): def generate(model, nb, local_device=main_device): + model.eval().to(local_device) + all_input = quiz_machine.pure_noise(nb, local_device) all_masks = all_input.new_full(all_input.size(), 1) @@ -622,7 +624,12 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device): correct_parts=correct_parts[:128], ) - model.test_accuracy = correct.sum() / quizzes.size(0) + nb_correct, nb_total = correct.sum(), quizzes.size(0) + model.test_accuracy = nb_correct / nb_total + + log_string( + f"test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({model.test_accuracy:.02f}%)" + ) # generate @@ -634,6 +641,22 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device): ) +###################################################################### + + +class TokenCat(nn.Module): + def __init__(self, m, n): + super().__init__() + self.m = m + self.n = n + + def forward(self, x): + u = torch.cat([x.new_zeros(x.size(0), self.n), x], dim=1) + u = self.m(u) + u = u[:, self.n :] + return u + + ###################################################################### import attae @@ -651,6 +674,9 @@ for i in range(args.nb_models): dropout=args.dropout, ).to(main_device) + # if i < args.nb_models//2: + # model = TokenCat(model, 10) + # model = torch.compile(model) model.id = i @@ -748,7 +774,7 @@ def quiz_validation_( ###################################################################### -def generate_c_quizzes(models, nb, local_device=main_device): +def generate_c_quizzes_(models, nb, local_device=main_device): # To be thread-safe we must make copies def copy_for_inference(model): @@ -842,6 +868,41 @@ def generate_c_quizzes(models, nb, local_device=main_device): ###################################################################### +def generate_c_quizzes(models, nb, local_device=main_device): + record = [] + nb_validated = 0 + while nb_validated < nb: + model = models[torch.randint(len(models), (1,)).item()] + model = copy.deepcopy(model).to(local_device).eval() + generator_id = model.id + + c_quizzes = generate( + moel=copy_for_inference(model), + nb=args.physical_batch_size, + local_device=local_device, + ) + + nb_correct, nb_wrong = 0, 0 + for i, model in enumerate(models): + model = copy.deepcopy(model).to(local_device).eval() + result = predict_full(model, c_quizzes, local_device=local_device) + nb_mistakes = (result != c_quizzes).long().sum(dim=1) + nb_correct += (nb_mistakes == 0).long() + nb_wrong += nb_mistakes >= args.nb_mistakes_to_be_wrong + + to_keep = (nb_correct >= args.nb_have_to_be_correct) & ( + nb_wrong >= args.nb_have_to_be_wrong + ) + + nb_validated += to_keep.long().sum() + record.append(c_quizzes[to_keep]) + + return torch.cat(record) + + +###################################################################### + + def save_c_quizzes_with_scores(models, c_quizzes, filename, solvable_only=False): l = [] @@ -1094,17 +1155,15 @@ for n_epoch in range(current_epoch, args.nb_epochs): ranked_models = sorted(models, key=lambda m: float(m.test_accuracy)) weakest_models = ranked_models[: len(gpus)] - # None if c_quizzes is None else c_quizzes[agreements[:, model.id]], - multithread_execution( one_complete_epoch, [(model, n_epoch, c_quizzes, gpu) for model, gpu in zip(weakest_models, gpus)], ) - # -------------------------------------------------------------------- - save_models(models) + # -------------------------------------------------------------------- + duration = time.perf_counter() - start_time str_duration = "" if duration >= 60: -- 2.39.5