From: François Fleuret Date: Wed, 17 Jul 2024 03:14:05 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=9b6d18308d8a1b19023023b00ceb6a2feea50ab1;p=culture.git Update. --- diff --git a/main.py b/main.py index ca1e9b5..76db5e2 100755 --- a/main.py +++ b/main.py @@ -82,13 +82,13 @@ parser.add_argument("--gpus", type=str, default="all") parser.add_argument("--nb_gpts", type=int, default=5) -parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.9) +parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.975) parser.add_argument("--proba_understands", type=float, default=0.9) parser.add_argument("--proba_not_understands", type=float, default=0.5) -parser.add_argument("--generation_temperature", type=float, default=2) +parser.add_argument("--generation_temperature", type=float, default=2.5) parser.add_argument("--c_quiz_validation_mode", type=str, default="predict") @@ -367,9 +367,12 @@ def one_epoch(model, quiz_machine, local_device=main_device): acc_train_loss += loss.item() * input.size(0) loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1) - hard_w_quizzes.append( - (input[from_w].to("cpu"), loss_per_samples[from_w].to("cpu")) - ) + n_forward = input[:, 0] == self.token_forward + to_store = from_w & n_forward + if to_store.any(): + hard_w_quizzes.append( + (input[to_store].to("cpu"), loss_per_samples[to_store].to("cpu")) + ) nb_train_samples += input.size(0) @@ -384,11 +387,11 @@ def one_epoch(model, quiz_machine, local_device=main_device): run_tests(model, quiz_machine, deterministic_synthesis=False) - threshold = torch.cat([x[1] for x in hard_w_quizzes], dim=0).sort().values + threshold = torch.cat([l for _, l in hard_w_quizzes], dim=0).sort().values threshold = threshold[threshold.size(0) // 2] model.hard_w_quizzes = torch.cat( - [x[0][x[1] >= threshold] for x in hard_w_quizzes], dim=0 + [x[l >= threshold] for x, l in hard_w_quizzes], dim=0 ) model.to(main_device) diff --git a/quiz_machine.py b/quiz_machine.py index 32b3f7e..1168921 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -350,9 +350,7 @@ class QuizMachine: ###################################################################### - def produce_results( - self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000 - ): + def produce_results(self, n_epoch, model, result_dir, deterministic_synthesis): def compute_accuracy(input, log_prefix=None): input = input.to(self.device) ar_mask = self.make_ar_mask(input) @@ -400,14 +398,15 @@ class QuizMachine: return result, correct - # compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train") - test_result, test_correct = compute_accuracy( - model.test_w_quizzes[:nmax], log_prefix="test" + model.test_w_quizzes[:2000], log_prefix="test" ) - main_test_accuracy = test_correct.sum() / test_correct.size(0) - # self.logger(f"main_test_accuracy {n_epoch} model {model.id} {main_test_accuracy}") + n_test_forward = model.test_w_quizzes[:, 0] == self.token_forward + + forward_test_correct = test_correct[n_test_forward] + + main_test_accuracy = forward_test_correct.sum() / forward_test_correct.size(0) ############################## @@ -459,6 +458,9 @@ class QuizMachine: else: input[...] = self.generate_token_sequences(input.size(0)) + if not forward_only: + self.reverse_random_half_in_place(input) + ###################################################################### def store_c_quizzes(self, new_c_quizzes, for_train=True):