From: François Fleuret Date: Thu, 4 Jul 2024 16:25:10 +0000 (+0300) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=455161a64dfc7a53d09ff1cd49f590ff9152cc37;p=culture.git Update. --- diff --git a/quizz_machine.py b/quizz_machine.py index 65b6000..62ae8ce 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -346,10 +346,6 @@ class QuizzMachine: .item() ) - self.logger( - f"back_accuracy {n_epoch=} {model.id=} {nb_correct=} {nb_total=}" - ) - n_backward = input[:, 0] == self.token_backward back_input = self.reverse_time(result[n_backward]) @@ -358,11 +354,20 @@ class QuizzMachine: n_backward, 1 : 1 + self.answer_len ] back_nb_total, back_nb_correct = compute_accuracy(back_input) + + self.logger( + f"accuracy {n_epoch=} {model.id=} {nb_correct} / {nb_total}" + ) self.logger( - f"back_accuracy {n_epoch=} {model.id=} {back_nb_correct=} {back_nb_total=}" + f"back_accuracy {n_epoch=} {model.id=} {back_nb_correct} / {back_nb_total}" ) + nb_total += back_nb_total nb_correct += back_nb_correct + else: + self.logger( + f"accuracy {n_epoch=} {model.id=} {nb_correct} / {nb_total}" + ) else: nb_total = input.size(0) diff --git a/reasoning.py b/reasoning.py index 2874adc..54a4203 100755 --- a/reasoning.py +++ b/reasoning.py @@ -293,7 +293,7 @@ class Reasoning(problem.Problem): X[i1:i2, j1:j2] = c[n] f_X[i1:i2, j1:j2] = c[n if n > 0 else -1] - def task_move(self, A, f_A, B, f_B): + def task_translate(self, A, f_A, B, f_B): di, dj = torch.randint(3, (2,)) - 1 nb_rec = 3 c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1 @@ -406,16 +406,31 @@ class Reasoning(problem.Problem): if n < nb_rec - 1: f_X[i1, j1] = c[-1] + def task_count(self, A, f_A, B, f_B): + N = torch.randint(3, (1,)) + 1 + c = torch.randperm(len(self.colors) - 1)[:N] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + nb = torch.randint(self.width, (3,)) + 1 + k = torch.randperm(self.height * self.width)[: nb.sum()] + p = 0 + for n in range(N): + for m in range(nb[n]): + i, j = k[p] % self.height, k[p] // self.height + X[i, j] = c[n] + f_X[n, m] = c[n] + p += 1 + ###################################################################### def generate_prompts_and_answers(self, nb, device="cpu"): tasks = [ self.task_replace_color, - self.task_move, + self.task_translate, self.task_grow, self.task_color_grow, self.task_frame, self.task_detect, + self.task_count, ] prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64) answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64) @@ -476,6 +491,6 @@ if __name__ == "__main__": prompts[:64], answers[:64], # You can add a bool to put a frame around the predicted parts - predicted_prompts[:64], - predicted_answers[:64], + # predicted_prompts[:64], + # predicted_answers[:64], )