From: François Fleuret Date: Mon, 15 Jul 2024 21:12:41 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=1e1394d00d3f7046f0e342b01ea96a240fafd72a;p=culture.git Update. --- diff --git a/main.py b/main.py index 9d36aba..5c58beb 100755 --- a/main.py +++ b/main.py @@ -90,7 +90,7 @@ parser.add_argument("--proba_not_understands", type=float, default=0.5) parser.add_argument("--generation_temperature", type=float, default=2) -parser.add_argument("--c_quiz_validation_mode", type=str, default="proba") +parser.add_argument("--c_quiz_validation_mode", type=str, default="predict") parser.add_argument("--dirty_debug", action="store_true", default=False) diff --git a/quiz_machine.py b/quiz_machine.py index f66258a..0f834dc 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -537,10 +537,8 @@ class QuizMachine: seq_logproba = torch.zeros(nb, device=self.device) - # First, we generate the answer at high temperature - - c_quizzes[:, 0] = self.token_backward - c_quizzes[:, 1 + self.answer_len] = self.token_backward + c_quizzes[:, 0] = self.token_forward + c_quizzes[:, 1 + self.prompt_len] = self.token_forward masked_inplace_autoregression( model=model_for_generation, @@ -548,29 +546,11 @@ class QuizMachine: input=c_quizzes, ar_mask=self.make_ar_mask(c_quizzes, first=True), seq_logproba=seq_logproba, - temperature=temperature, - deterministic_synthesis=False, - device=self.device, - ) - - # Then, we generate the prompt at low temperature - - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes, - ar_mask=self.make_ar_mask(c_quizzes), - seq_logproba=seq_logproba, temperature=1.0, deterministic_synthesis=False, device=self.device, ) - # Then we return the quizz, and re-generate the response, now - # at low temperature - - c_quizzes = self.reverse_time(c_quizzes) - masked_inplace_autoregression( model=model_for_generation, batch_size=self.batch_size,