parser.add_argument("--dirty_debug", action="store_true", default=False)
+parser.add_argument("--generation_temperature", type=float, default=1.0)  # forwarded as temperature= to quizz_machine.generate_quizzes
+
+parser.add_argument("--stochastic_validation", action="store_true", default=False)  # when set, compute_correctness receives deterministic_validation=False
+
+######################################################################
+
parser.add_argument("--sky_height", type=int, default=6)
parser.add_argument("--sky_width", type=int, default=8)
c_quizzes = quizz_machine.generate_quizzes(
nb_to_create,
model_for_generation=model_for_generation,
+ temperature=args.generation_temperature,
)
nb_correct, seq_logproba = quizz_machine.compute_correctness(
- c_quizzes, models, both_directions=args.both_directions
+ c_quizzes,
+ models,
+ both_directions=args.both_directions,
+ deterministic_validation=not args.stochastic_validation,
)
for n, l in zip(nb_correct, seq_logproba):
)
def compute_correctness(
- self, c_quizzes, models_for_validation, both_directions=False
+ self,
+ c_quizzes,
+ models_for_validation,
+ both_directions=False,
+ deterministic_validation=True,
):
reversed_c_quizzes = self.reverse_time(c_quizzes)
ar_mask=ar_mask,
seq_logproba=seq_logproba[:, model.id],
temperature=1.0,
- deterministic_synthesis=True,
+ deterministic_synthesis=deterministic_validation,
# progress_bar_desc="solving c_quizzes",
device=self.device,
)
ar_mask=ar_mask,
seq_logproba=seq_logproba[:, model.id],
temperature=1.0,
- deterministic_synthesis=True,
+ deterministic_synthesis=deterministic_validation,
# progress_bar_desc="solving reversed c_quizzes",
device=self.device,
)
###############################################################
- def generate_quizzes(self, nb, model_for_generation):
+ def generate_quizzes(self, nb, model_for_generation, temperature=1.0):
c_quizzes = torch.empty(
nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
)
seq_logproba = torch.zeros(ar_mask_first.size(0), device=self.device)
- temperature = 10.0
-
# First, we generate the answer at high temperature
c_quizzes[:, 0] = self.token_backward
device=self.device,
)
- # Then, we generate the prompt deterministically
+ # Then, we generate the prompt at low temperature
masked_inplace_autoregression(
model=model_for_generation,
input=c_quizzes,
ar_mask=ar_mask_second,
seq_logproba=seq_logproba,
- temperature=1.0,
- deterministic_synthesis=True,
+ temperature=1 / temperature,
+ deterministic_synthesis=False,
device=self.device,
)
# Then we reverse the quizz, and re-generate the response, now
- # deterministically
+ # at low temperature
c_quizzes = self.reverse_time(c_quizzes)
input=c_quizzes,
ar_mask=ar_mask_second,
seq_logproba=seq_logproba,
- temperature=temperature,
- deterministic_synthesis=True,
+ temperature=1 / temperature,
+ deterministic_synthesis=False,
device=self.device,
)
"_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
)
- def nb_token_values(self):
- return len(self.colors)
-
def __init__(
self,
height=6,
######################################################################
- def generate_prompts_and_answers(self, nb):
- frame_sequences = self.generate_frame_sequences(nb)
- frame_sequences = torch.cat([x[None] for x in frame_sequences], dim=0)
- prompts = frame_sequences[:, : frame_sequences.size(1) // 2].flatten(1)
- answers = frame_sequences[:, frame_sequences.size(1) // 2 :].flatten(1)
- return prompts, answers
-
- ######################################################################
-
def frame2img(self, x, scale=15):
x = x.reshape(x.size(0), self.height, -1)
m = torch.logical_and(
img.float() / 255.0, image_name, nrow=6, padding=margin * 2, pad_value=1.0
)
+ ######################################################################
+
+ def nb_token_values(self):  # number of distinct token values, taken as the number of entries in self.colors
+ return len(self.colors)
+
+ def generate_prompts_and_answers(self, nb):  # returns (prompts, answers) built from self.generate_frame_sequences(nb)
+ frame_sequences = self.generate_frame_sequences(nb)
+ frame_sequences = torch.cat([x[None] for x in frame_sequences], dim=0)  # stack the sequences along a new leading dimension
+ prompts = frame_sequences[:, : frame_sequences.size(1) // 2].flatten(1)  # first size(1) // 2 entries along dim 1, flattened per sample
+ answers = frame_sequences[:, frame_sequences.size(1) // 2 :].flatten(1)  # the rest along dim 1, flattened per sample
+ return prompts, answers
+
def save_quizzes(
self,
result_dir,