task,
nb_for_train=1000,
nb_for_test=100,
- desired_average_logits=None,
+ min_ave_seq_logproba=None,
):
kept = []
while sum([x.size(0) for x in kept]) < nb_for_train + nb_for_test:
nb_to_generate = 4 * (nb_for_train + nb_for_test)
- new_c_quizzes, nb_correct, average_logits = task.create_c_quizzes(
+ new_c_quizzes, nb_correct, ave_seq_logproba = task.create_c_quizzes(
n_epoch=n_epoch,
result_dir=args.result_dir,
logger=log_string,
nb=nb_to_generate,
model=model,
other_models=other_models,
- desired_average_logits=desired_average_logits,
+ min_ave_seq_logproba=min_ave_seq_logproba,
)
- sum_logits += new_c_quizzes.size(0) * average_logits
+ sum_logits += new_c_quizzes.size(0) * ave_seq_logproba
sum_nb_c_quizzes += new_c_quizzes.size(0)
to_keep = new_c_quizzes[nb_correct == len(other_models) - 1]
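
# A toy helper (hypothetical, not part of the diff) making explicit the arithmetic
# behind sum_logits and sum_nb_c_quizzes above: each generation round contributes
# its size times its average sequence logproba, so dividing the two running totals
# recovers the average over everything generated so far.
def overall_ave_seq_logproba(round_sizes, round_averages):
    total_logproba = sum(n * a for n, a in zip(round_sizes, round_averages))
    return total_logproba / sum(round_sizes)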
######################################################################
-desired_average_logits = None
+min_ave_seq_logproba = None
for n_epoch in range(args.nb_epochs):
log_string(f"--- epoch {n_epoch} ----------------------------------------")
other_models = models.copy()
other_models.remove(model)
- average_logits = create_c_quizzes(
+ ave_seq_logproba = create_c_quizzes(
model,
other_models,
task,
nb_for_train=nb_new_c_quizzes_for_train,
nb_for_test=nb_new_c_quizzes_for_test,
- desired_average_logits=desired_average_logits,
+ min_ave_seq_logproba=min_ave_seq_logproba,
)
- # We keep the first average logits as a reference
+ # We keep the first average sequence logproba as a reference
- if desired_average_logits is None:
- desired_average_logits = average_logits
+ if min_ave_seq_logproba is None:
+ min_ave_seq_logproba = ave_seq_logproba
else:
log_string(
- f"desired_average_logits {desired_average_logits} average_logits {average_logits}"
+ f"min_ave_seq_logproba {min_ave_seq_logproba} ave_seq_logproba {ave_seq_logproba}"
)
# We update everyone
self,
input,
ar_mask,
- summed_logits,
+ seq_logproba,
temperature=1.0,
deterministic_synthesis=False,
forbidden_tokens=None,
else:
dist = torch.distributions.categorical.Categorical(logits=logits)
t_next = dist.sample()
- if summed_logits is not None:
- summed_logits += logits[torch.arange(t_next.size(0)), t_next].sum(
- dim=-1
- )
+
+ if seq_logproba is not None:
+ all_t = torch.arange(t_next.size(0))
+ # the batch sum is a 0-d tensor, broadcast-added to every entry of seq_logproba
+ seq_logproba += logits[all_t, t_next].sum(dim=-1)
input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
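
# A self-contained sketch (illustrative names, not the repository's API) of a
# strictly per-sequence accumulation at this point in the sampling loop: normalize
# the scores, then add each sequence's own sampled-token logproba. The hunk above
# instead adds the batch-summed value to every entry of seq_logproba.
import torch

def accumulate_seq_logproba(logits, t_next, seq_logproba):
    # logits: (batch, vocab) scores at the current position
    # t_next: (batch,) token indices just sampled
    # seq_logproba: (batch,) running per-sequence totals, updated in place
    logproba = logits.log_softmax(dim=-1)
    seq_logproba += logproba[torch.arange(t_next.size(0)), t_next]
    return seq_logproba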
batch_size,
input,
ar_mask,
- summed_logits,
+ seq_logproba,
temperature,
deterministic_synthesis,
forbidden_tokens=None,
model.masked_inplace_autoregression(
input=input,
ar_mask=ar_mask,
- summed_logits=summed_logits,
+ seq_logproba=seq_logproba,
temperature=temperature,
deterministic_synthesis=deterministic_synthesis,
forbidden_tokens=forbidden_tokens,
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
- summed_logits=None,
+ seq_logproba=None,
temperature=1.0,
deterministic_synthesis=deterministic_synthesis,
progress_bar_desc=None,
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
- summed_logits=None,
+ seq_logproba=None,
temperature=1.0,
deterministic_synthesis=deterministic_synthesis,
progress_bar_desc=None,
nb,
model,
other_models,
- desired_average_logits=None,
+ min_ave_seq_logproba=None,
):
###############################################################
# Generate quizzes with model
)
ar_mask = torch.full(c_quizzes.size(), 1, device=self.device)
- summed_logits = torch.empty(nb, device=self.device)
+ seq_logproba = torch.empty(nb, device=self.device)
temperature = 1
d_temperature = 1
while True:
- summed_logits[...] = 0
+ seq_logproba[...] = 0
masked_inplace_autoregression(
model=model,
batch_size=self.batch_size,
input=c_quizzes,
ar_mask=ar_mask,
- summed_logits=summed_logits,
+ seq_logproba=seq_logproba,
temperature=temperature,
deterministic_synthesis=False,
progress_bar_desc="sampling c_quizzes",
device=self.device,
)
- average_logits = summed_logits.mean()
+ ave_seq_logproba = seq_logproba.mean()
- logger(f"{average_logits=} {desired_average_logits=}")
+ logger(f"{ave_seq_logproba=} {min_ave_seq_logproba=}")
- if desired_average_logits is None:
+ if min_ave_seq_logproba is None:
break
- # Oh man that's ugly
+ # Crude dichotomic search on the sampling temperature: the step d_temperature is
+ # halved and its sign flipped whenever ave_seq_logproba overshoots, steering it
+ # into the band delimited by min_ave_seq_logproba * 1.1 and min_ave_seq_logproba
- if average_logits < desired_average_logits * 1.1:
+ if ave_seq_logproba < min_ave_seq_logproba * 1.1:
if d_temperature > 0:
d_temperature *= -0.5
temperature += d_temperature
- elif average_logits > desired_average_logits:
+ elif ave_seq_logproba > min_ave_seq_logproba:
if d_temperature < 0:
d_temperature *= -0.5
temperature += d_temperature
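
# A standalone sketch of the temperature adjustment driving the loop above; score()
# is a hypothetical callable standing in for a full sampling round, and the early
# exit when neither branch fires is an assumption of this sketch, not shown in the
# excerpt. The step halves and flips sign on each overshoot, so the search settles
# once the measured value is neither below min_target * 1.1 nor above min_target.
def tune_temperature(score, min_target, temperature=1.0, d_temperature=1.0, nb_steps=20):
    for _ in range(nb_steps):
        s = score(temperature)
        if s < min_target * 1.1:
            if d_temperature > 0:
                d_temperature *= -0.5
            temperature += d_temperature
        elif s > min_target:
            if d_temperature < 0:
                d_temperature *= -0.5
            temperature += d_temperature
        else:
            break
    return temperature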
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
- summed_logits=None,
+ seq_logproba=None,
temperature=1.0,
deterministic_synthesis=True,
progress_bar_desc="solving c_quizzes",
batch_size=self.batch_size,
input=reverse_result,
ar_mask=ar_mask,
- summed_logits=None,
+ seq_logproba=None,
temperature=1.0,
deterministic_synthesis=True,
progress_bar_desc="solving reversed c_quizzes",
nb_correct = torch.cat(nb_correct, dim=0).sum(dim=0)
- # filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat")
- # with open(filename, "w") as f:
- # for k in nb_correct:
- # f.write(f"{k}\n")
-
- return c_quizzes, nb_correct, summed_logits.mean()
+ return c_quizzes, nb_correct, seq_logproba.mean()
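
# A toy illustration (made-up tensors, shapes inferred from the cat/sum above and
# from the caller's filter) of the nb_correct bookkeeping: each of the other models
# contributes a (1, nb) row of 0/1 "solved" flags, the rows are concatenated and
# summed over the model dimension, and the caller keeps the quizzes solved by
# exactly all but one of the other models.
import torch

flags = [torch.tensor([[1, 0, 1, 1]]), torch.tensor([[1, 1, 0, 1]])]  # one row per other model
nb_correct = torch.cat(flags, dim=0).sum(dim=0)                       # (nb,) counts per quiz
to_keep_mask = nb_correct == len(flags) - 1                           # here: solved by exactly one model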