From 015c9f73a588666b2b20887623a4ed1e0200873e Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 23 Jul 2024 06:04:03 +0200 Subject: [PATCH] Update. --- grids.py | 17 +--- quiz_machine.py | 212 ------------------------------------------------ 2 files changed, 3 insertions(+), 226 deletions(-) diff --git a/grids.py b/grids.py index b531eb9..ba0131d 100755 --- a/grids.py +++ b/grids.py @@ -177,28 +177,17 @@ class Grids(problem.Problem): ) else: flipped_from_forward = torch.cat( - [ - quizzes[:, 3 * (S + 1) : 3 * (S + 1) + S + 1], - quizzes[:, 0 * (S + 1) : 2 * (S + 1) + S + 1], - quizzes[:, 1 * (S + 1) : 1 * (S + 1) + S + 1], - quizzes[:, 2 * (S + 1) : 0 * (S + 1) + S + 1], - ], + [quizzes[:, 3 * (S + 1) :], quizzes[:, : 3 * (S + 1)]], dim=1, ) flipped_from_forward[:, torch.arange(4) * (S + 1)] = self.token_backward flipped_from_backward = torch.cat( - [ - quizzes[:, 1 * (S + 1) : 3 * (S + 1) + S + 1], - quizzes[:, 2 * (S + 1) : 2 * (S + 1) + S + 1], - quizzes[:, 3 * (S + 1) : 1 * (S + 1) + S + 1], - quizzes[:, 0 * (S + 1) : 0 * (S + 1) + S + 1], - ], - dim=1, + [quizzes[:, S + 1 :], quizzes[:, : S + 1]], dim=1 ) flipped_from_backward[:, torch.arange(4) * (S + 1)] = self.token_forward - m = (flipped[:, 0] == self.token_forward).long() + m = (quizzes[:, 0] == self.token_forward).long()[:, None] flipped = m * flipped_from_forward + (1 - m) * flipped_from_backward diff --git a/quiz_machine.py b/quiz_machine.py index 182e9ff..b1f6be1 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -603,215 +603,3 @@ class QuizMachine: return c_quizzes.to("cpu") ###################################################################### - - def generate_c_quizzes_fixed_point( - self, - nb, - model_for_generation, - p2a_only=False, - temperature_hot=1.0, - temperature_cold=1.0, - ): - c_quizzes = torch.empty( - nb, - self.prompt_len + self.answer_len, - device=self.device, - dtype=torch.int64, - ) - - seq_logproba = torch.zeros(nb, device=self.device) - - lt_noisy = lambda s, logits: logits / temperature_hot - lt_clean = lambda s, logits: logits / temperature_cold - - c_quizzes[...] = self.problem.token_backward - - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes, - ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"), - seq_logproba=seq_logproba, - logit_transformer=lt_noisy, - deterministic_synthesis=False, - device=self.device, - ) - - self.save_quiz_illustrations("/tmp", f"c_quizzes_before", c_quizzes) - - c_quizzes = self.problem.p_a_flip(c_quizzes) - - while True: - print("ITERATION") - - c_quizzes = self.problem.p_a_flip(c_quizzes) - - pred = c_quizzes.clone() - - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes, - ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"), - seq_logproba=seq_logproba, - logit_transformer=lt_clean, - deterministic_synthesis=False, - device=self.device, - ) - - c_quizzes = self.problem.p_a_flip(c_quizzes) - - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes, - ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"), - seq_logproba=seq_logproba, - logit_transformer=lt_clean, - deterministic_synthesis=False, - device=self.device, - ) - - if pred[202:].equal(c_quizzes[202:]): - break - - self.save_quiz_illustrations("/tmp", f"c_quizzes_after", c_quizzes) - - exit(0) - - return c_quizzes.to("cpu") - - ###################################################################### - - def generate_c_quizzes_mixing( - self, - nb, - model_for_generation, - p2a_only=False, - temperature_hot=1.0, - temperature_cold=1.0, - ): - c_quizzes = torch.empty( - nb, - self.prompt_len + self.answer_len, - device=self.device, - dtype=torch.int64, - ) - - c_quizzes_1 = torch.empty( - nb, - self.prompt_len + self.answer_len, - device=self.device, - dtype=torch.int64, - ) - - c_quizzes_2 = torch.empty( - nb, - self.prompt_len + self.answer_len, - device=self.device, - dtype=torch.int64, - ) - - seq_logproba = torch.zeros(nb, device=self.device) - - lt_noisy = lambda s, logits: logits / temperature_hot - lt_clean = lambda s, logits: logits / temperature_cold - - ###################################################################### - - c_quizzes_1[...] = self.problem.token_backward - ar_mask = self.problem.make_ar_mask(c_quizzes_1, shape="fwd_012_bck_0") - - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes_1, - ar_mask=ar_mask, - seq_logproba=seq_logproba, - logit_transformer=lt_noisy, - deterministic_synthesis=False, - device=self.device, - ) - - self.save_quiz_illustrations("/tmp", f"c_quizzes_1", c_quizzes_1) - - c_quizzes_2[...] = self.problem.token_backward - - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes_2, - ar_mask=ar_mask, - seq_logproba=seq_logproba, - logit_transformer=lt_noisy, - deterministic_synthesis=False, - device=self.device, - ) - - self.save_quiz_illustrations("/tmp", f"c_quizzes_2", c_quizzes_2) - - h = len(model_for_generation.trunk) // 2 - - with torch.autograd.no_grad(): - t = model_for_generation.training - model_for_generation.eval() - - bs1 = model_for_generation.partial_forward( - mygpt.BracketedSequence(c_quizzes_1), end_layer=h - ) - bs2 = model_for_generation.partial_forward( - mygpt.BracketedSequence(c_quizzes_2), end_layer=h - ) - - alpha = 0.1 - - output = model_for_generation.partial_forward( - mygpt.BracketedSequence(alpha * bs1.x + (1 - alpha) * bs2.x), - start_layer=h, - ).x - - dist = torch.distributions.categorical.Categorical(logits=output) - c_quizzes[...] = dist.sample() - - c_quizzes[...] = ( - ar_mask * c_quizzes + (1 - ar_mask) * self.problem.token_backward - ) - - model_for_generation.train(t) - - self.save_quiz_illustrations("/tmp", f"c_quizzes", c_quizzes) - - ###################################################################### - - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes, - ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"), - seq_logproba=seq_logproba, - logit_transformer=lt_clean, - deterministic_synthesis=False, - device=self.device, - ) - - self.save_quiz_illustrations("/tmp", f"c_quizzes_A", c_quizzes) - - c_quizzes = self.problem.p_a_flip(c_quizzes) - - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes, - ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"), - seq_logproba=seq_logproba, - logit_transformer=lt_clean, - deterministic_synthesis=False, - device=self.device, - ) - - self.save_quiz_illustrations("/tmp", f"c_quizzes_B", c_quizzes) - - print("DONE") - exit(0) - - return c_quizzes.to("cpu") -- 2.39.5