From: François Fleuret Date: Fri, 5 Jul 2024 22:10:41 +0000 (+0300) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=ed64a064ef1d8d3e53c7961480ffafdd516ea984;p=culture.git Update. --- diff --git a/quizz_machine.py b/quizz_machine.py index 632c9ae..717e8ac 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -325,7 +325,7 @@ class QuizzMachine: def produce_results( self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000 ): - def compute_accuracy(input): + def compute_accuracy(input, log_prefix=None): ar_mask = self.make_ar_mask(input) result = input.clone() * (1 - ar_mask) seq_logproba = torch.empty(input.size(0), device=self.device) @@ -342,96 +342,56 @@ class QuizzMachine: device=self.device, ) - if self.back_accuracy: - # If back_accuracy is True, we compute the accuracy on - # the backward quizzes not by counting how many time - # the real prompt A is equal to the reconstructed - # prompt A*, but how many time the answers B* computed - # from A* is equal to the correct answer. So we look - # for the accuracy of A->B*=B for the forward, but for - # the backward we look at B->A*->B*=B instead of B->A*=A - - n_forward = input[:, 0] == self.token_forward - nb_total = input[n_forward].size(0) - nb_correct = ( - (input[n_forward] == result[n_forward]) - .long() - .min(dim=1) - .values.sum() - .item() - ) + correct = torch.empty(input.size(0), dtype=torch.int64, device=input.device) - n_backward = input[:, 0] == self.token_backward - back_input = self.reverse_time(result[n_backward]) + n_forward = input[:, 0] == self.token_forward + n_backward = input[:, 0] == self.token_backward - if back_input.size(0) > 0: - back_input[:, 2 + self.prompt_len :] = input[ - n_backward, 1 : 1 + self.answer_len - ] - back_nb_total, back_nb_correct = compute_accuracy(back_input) - - self.logger( - f"accuracy {n_epoch=} {model.id=} {nb_correct} / {nb_total}" - ) - self.logger( - f"back_accuracy {n_epoch=} {model.id=} {back_nb_correct} / {back_nb_total}" - ) - - nb_total += back_nb_total - nb_correct += back_nb_correct - else: - self.logger( - f"accuracy {n_epoch=} {model.id=} {nb_correct} / {nb_total}" - ) + correct[n_forward] = ( + (input[n_forward] == result[n_forward]).long().min(dim=1).values + ) - else: - nb_total = input.size(0) - nb_correct = (input == result).long().min(dim=1).values.sum() + if self.back_accuracy and n_backward.any(): + # accuracy of B->A*->B*=B instead of B->A*=A + back_input = self.reverse_time(result[n_backward]) + back_input[:, 2 + self.prompt_len :] = input[ + n_backward, 1 : 1 + self.answer_len + ] + result[n_backward], correct[n_backward] = compute_accuracy(back_input) - return nb_total, nb_correct + if log_prefix is not None: + nb_correct = correct[n_forward].sum() + nb_total = correct[n_forward].size(0) + back_nb_correct = correct[n_backward].sum() + back_nb_total = correct[n_backward].size(0) - train_nb_total, train_nb_correct = compute_accuracy(self.train_w_quizzes[:nmax]) + self.logger( + f"accuracy {log_prefix} {n_epoch} {model.id=} {nb_correct} / {nb_total}" + ) - self.logger( - f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%" - ) + self.logger( + f"back_accuracy {log_prefix} {n_epoch} {model.id=} {back_nb_correct} / {back_nb_total}" + ) - test_nb_total, test_nb_correct = compute_accuracy(self.test_w_quizzes[:nmax]) + return result, correct - self.logger( - f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" + compute_accuracy(self.train_w_quizzes[:nmax], log_prefix="train") + + result, correct = compute_accuracy( + self.test_w_quizzes[:nmax], log_prefix="test" ) - main_test_accuracy = test_nb_correct / test_nb_total + main_test_accuracy = correct.sum() / correct.size(0) self.logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}") ############################## - input = self.test_w_quizzes[:96] - ar_mask = self.make_ar_mask(input) - result = input.clone() * (1 - ar_mask) - seq_logproba = torch.empty(input.size(0), device=self.device) - - masked_inplace_autoregression( - model=model, - batch_size=self.batch_size, - input=result, - ar_mask=ar_mask, - seq_logproba=seq_logproba, - temperature=1.0, - deterministic_synthesis=deterministic_synthesis, - progress_bar_desc=None, - device=self.device, - ) - - mistakes = (input == result).flatten(1).long().min(dim=1).values * 2 - 1 - self.save_quizzes( result_dir, f"culture_prediction_{n_epoch:04d}_{model.id:02d}", quizzes=result[:72], show_to_be_predicted=True, - mistakes=mistakes[:72], + mistakes=correct[:72] * 2 - 1, ) return main_test_accuracy diff --git a/reasoning.py b/reasoning.py index 5499bdf..aa566b0 100755 --- a/reasoning.py +++ b/reasoning.py @@ -588,6 +588,38 @@ class Reasoning(problem.Problem): X[i, j] = c[1] f_X[0:2, 0:2] = c[1] + def task_islands(self, A, f_A, B, f_B): + for X, f_X in [(A, f_A), (B, f_B)]: + while True: + i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,)) + if ( + i == 0 + or i == self.height - 1 + or j == 0 + or j == self.width - 1 + or X[i, j] == 1 + ): + break + while True: + di, dj = torch.randint(3, (2,)) - 1 + if abs(di) + abs(dj) > 0: + break + X[i, j] = 1 + while True: + i, j = i + di, j + dj + if i < 0 or i >= self.height or j < 0 or j >= self.width: + break + b = ( + i == 0 + or i == self.height - 1 + or j == 0 + or j == self.width - 1 + or X[i, j] == 1 + ) + X[i, j] = 1 + if b: + break + ###################################################################### def all_tasks(self): @@ -602,6 +634,7 @@ class Reasoning(problem.Problem): self.task_trajectory, self.task_bounce, self.task_scale, + self.task_islands, ] def generate_prompts_and_answers(self, nb, tasks=None, device="cpu"): @@ -657,21 +690,23 @@ if __name__ == "__main__": reasoning = Reasoning() - for t in reasoning.all_tasks(): + for t in [reasoning.task_islands]: # reasoning.all_tasks(): print(t.__name__) prompts, answers = reasoning.generate_prompts_and_answers(nb, tasks=[t]) reasoning.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=1) exit(0) + nb = 72 + start_time = time.perf_counter() prompts, answers = reasoning.generate_prompts_and_answers(nb) delay = time.perf_counter() - start_time print(f"{prompts.size(0)/delay:02f} seq/s") - # m = torch.randint(2, (prompts.size(0),)) - # predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1) - # predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1) + m = torch.randint(2, (prompts.size(0),)) + predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1) + predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1) reasoning.save_quizzes( "/tmp", @@ -679,6 +714,6 @@ if __name__ == "__main__": prompts[:nb], answers[:nb], # You can add a bool to put a frame around the predicted parts - # predicted_prompts[:nb], - # predicted_answers[:nb], + predicted_prompts[:nb], + predicted_answers[:nb], )