From 1506fb905b0f83034107e8e8dc336d10bdb1a7a7 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 22 Jun 2024 16:09:51 +0200 Subject: [PATCH] Update. --- tasks.py | 52 +++++++++++++++++++++++++++++++++------------------- 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/tasks.py b/tasks.py index f6d34a8..cb5900b 100755 --- a/tasks.py +++ b/tasks.py @@ -238,27 +238,46 @@ class World(Task): model, other_models, ): - new_quizzes = torch.empty( + ############################################################### + # Generate quizzes with model + + quizzes = torch.empty( nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64 ) - ar_mask = torch.full(new_quizzes.size(), 1, device=self.device) + ar_mask = torch.full(quizzes.size(), 1, device=self.device) masked_inplace_autoregression( model, self.batch_size, - new_quizzes, + quizzes, ar_mask, deterministic_synthesis=False, progress_bar_desc="creating quizzes", device=self.device, ) - ar_mask = self.make_ar_mask(new_quizzes) + ############################################################### + # Create the reverse quizzes + + l = self.height * self.width + direction = quizzes[:, l : l + 1] + direction = world.token_forward * ( + direction == world.token_backward + ) + world.token_backward * (direction == world.token_forward) + reverse_quizzes = torch.cat( + [quizzes[:, l + 1 :], direction, quizzes[:, :l]], dim=1 + ) + + ar_mask = self.make_ar_mask(quizzes) + + ############################################################### + # Check how many of the other models can solve them in both + # directions nb_correct = 0 for m in other_models: - result = new_quizzes.clone() + result = quizzes.clone() masked_inplace_autoregression( m, @@ -270,29 +289,24 @@ class World(Task): device=self.device, ) - l = self.height * self.width - direction = new_quizzes[:, l : l + 1] - direction = world.token_forward * ( - direction == world.token_backward - ) + world.token_backward * (direction == world.token_forward) - inverted_quizzes = torch.cat( - [new_quizzes[:, l + 1 :], direction, new_quizzes[:, :l]], dim=1 - ) + correct = (quizzes == result).long().min(dim=-1).values - inverted_result = inverted_quizzes.clone() + reverse_result = reverse_quizzes.clone() masked_inplace_autoregression( m, self.batch_size, - inverted_result, + reverse_result, ar_mask, deterministic_synthesis=True, progress_bar_desc="solving reversed quizzes", device=self.device, ) - nb_correct += (new_quizzes == result).long().min(dim=-1).values * ( - inverted_quizzes == inverted_result - ).long().min(dim=-1).values + reverse_correct = ( + (reverse_quizzes == reverse_result).long().min(dim=-1).values + ) + + nb_correct += correct * reverse_correct - return new_quizzes, nb_correct + return quizzes, nb_correct -- 2.39.5