From: François Fleuret
Date: Tue, 2 Jul 2024 16:33:15 +0000 (+0300)
Subject: Update.
X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=64abc9f3a07a8211f308271fde7d8f876a968ab5;p=culture.git

Update.
---

diff --git a/main.py b/main.py
index 7b8b642..918f75d 100755
--- a/main.py
+++ b/main.py
@@ -93,6 +93,12 @@ parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.975)
 
 parser.add_argument("--dirty_debug", action="store_true", default=False)
 
+parser.add_argument("--generation_temperature", type=float, default=1.0)
+
+parser.add_argument("--stochastic_validation", action="store_true", default=False)
+
+######################################################################
+
 parser.add_argument("--sky_height", type=int, default=6)
 
 parser.add_argument("--sky_width", type=int, default=8)
@@ -411,10 +417,14 @@ def create_c_quizzes(
             c_quizzes = quizz_machine.generate_quizzes(
                 nb_to_create,
                 model_for_generation=model_for_generation,
+                temperature=args.generation_temperature,
             )
 
             nb_correct, seq_logproba = quizz_machine.compute_correctness(
-                c_quizzes, models, both_directions=args.both_directions
+                c_quizzes,
+                models,
+                both_directions=args.both_directions,
+                deterministic_validation=not args.stochastic_validation,
             )
 
             for n, l in zip(nb_correct, seq_logproba):
diff --git a/quizz_machine.py b/quizz_machine.py
index 470b095..9b64941 100755
--- a/quizz_machine.py
+++ b/quizz_machine.py
@@ -322,7 +322,11 @@ class QuizzMachine:
         )
 
     def compute_correctness(
-        self, c_quizzes, models_for_validation, both_directions=False
+        self,
+        c_quizzes,
+        models_for_validation,
+        both_directions=False,
+        deterministic_validation=True,
     ):
         reversed_c_quizzes = self.reverse_time(c_quizzes)
 
@@ -349,7 +353,7 @@ class QuizzMachine:
                 ar_mask=ar_mask,
                 seq_logproba=seq_logproba[:, model.id],
                 temperature=1.0,
-                deterministic_synthesis=True,
+                deterministic_synthesis=deterministic_validation,
                 # progress_bar_desc="solving c_quizzes",
                 device=self.device,
             )
@@ -366,7 +370,7 @@ class QuizzMachine:
                     ar_mask=ar_mask,
                     seq_logproba=seq_logproba[:, model.id],
                     temperature=1.0,
-                    deterministic_synthesis=True,
+                    deterministic_synthesis=deterministic_validation,
                     # progress_bar_desc="solving reversed c_quizzes",
                     device=self.device,
                 )
@@ -385,7 +389,7 @@ class QuizzMachine:
 
     ###############################################################
 
-    def generate_quizzes(self, nb, model_for_generation):
+    def generate_quizzes(self, nb, model_for_generation, temperature=1.0):
         c_quizzes = torch.empty(
             nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
         )
@@ -398,8 +402,6 @@ class QuizzMachine:
 
         seq_logproba = torch.zeros(ar_mask_first.size(0), device=self.device)
 
-        temperature = 10.0
-
         # First, we generate the answer at high temperature
 
         c_quizzes[:, 0] = self.token_backward
@@ -415,7 +417,7 @@ class QuizzMachine:
             device=self.device,
         )
 
-        # Then, we generate the prompt deterministically
+        # Then, we generate the prompt at low temperature
 
         masked_inplace_autoregression(
             model=model_for_generation,
@@ -423,13 +425,13 @@ class QuizzMachine:
             input=c_quizzes,
             ar_mask=ar_mask_second,
             seq_logproba=seq_logproba,
-            temperature=1.0,
-            deterministic_synthesis=True,
+            temperature=1 / temperature,
+            deterministic_synthesis=False,
             device=self.device,
         )
 
         # Then we return the quizz, and re-generate the response, now
-        # deterministically
+        # at low temperature
 
         c_quizzes = self.reverse_time(c_quizzes)
 
@@ -439,8 +441,8 @@ class QuizzMachine:
             input=c_quizzes,
             ar_mask=ar_mask_second,
             seq_logproba=seq_logproba,
-            temperature=temperature,
-            deterministic_synthesis=True,
+            temperature=1 / temperature,
+            deterministic_synthesis=False,
             device=self.device,
         )
 
diff --git a/sky.py b/sky.py
index 2183cf1..040ec67 100755
--- a/sky.py
+++ b/sky.py
@@ -42,9 +42,6 @@ class Sky(problem.Problem):
         "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
     )
 
-    def nb_token_values(self):
-        return len(self.colors)
-
     def __init__(
         self,
         height=6,
@@ -155,15 +152,6 @@ class Sky(problem.Problem):
 
     ######################################################################
 
-    def generate_prompts_and_answers(self, nb):
-        frame_sequences = self.generate_frame_sequences(nb)
-        frame_sequences = torch.cat([x[None] for x in frame_sequences], dim=0)
-        prompts = frame_sequences[:, : frame_sequences.size(1) // 2].flatten(1)
-        answers = frame_sequences[:, frame_sequences.size(1) // 2 :].flatten(1)
-        return prompts, answers
-
-    ######################################################################
-
     def frame2img(self, x, scale=15):
         x = x.reshape(x.size(0), self.height, -1)
         m = torch.logical_and(
@@ -250,6 +238,18 @@ class Sky(problem.Problem):
             img.float() / 255.0, image_name, nrow=6, padding=margin * 2, pad_value=1.0
         )
 
+    ######################################################################
+
+    def nb_token_values(self):
+        return len(self.colors)
+
+    def generate_prompts_and_answers(self, nb):
+        frame_sequences = self.generate_frame_sequences(nb)
+        frame_sequences = torch.cat([x[None] for x in frame_sequences], dim=0)
+        prompts = frame_sequences[:, : frame_sequences.size(1) // 2].flatten(1)
+        answers = frame_sequences[:, frame_sequences.size(1) // 2 :].flatten(1)
+        return prompts, answers
+
     def save_quizzes(
         self,
         result_dir,