From: François Fleuret Date: Mon, 15 Jul 2024 18:26:15 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=7e8b2d697a7018caf7774a44cdfa67b590141844;p=culture.git Update. --- diff --git a/main.py b/main.py index 02259b2..ff36e98 100755 --- a/main.py +++ b/main.py @@ -368,53 +368,35 @@ def one_epoch(model, quiz_machine, local_device=main_device): ###################################################################### -# This is the key routine that decides what generated quizzes to keep +def keep_good_quizzes(models, quizzes): + quizzes = quizzes[quiz_machine.non_trivial(quizzes)] + token_logprobas = quiz_machine.solution_token_logprobas(models, quizzes) -# token_logprobas are NxMxT where M is the number of models -# def compute_valid_quizzes_(token_logprobas): -# warnings.warn("validation with uniform constraints", RuntimeWarning) -# l = token_logprobas.min(dim=-1).values.sort(dim=-1).values -# return (l[:, 0] < math.log(0.1)) & (l[:, 1] > math.log(0.5)) - -# token_logprobas are NxMxT where M is the number of models - - -def compute_valid_quizzes(token_logprobas): l = token_logprobas.sum(dim=-1).sort(dim=-1).values - return (l[:, 0] < math.log(args.proba_not_understands)) & ( + + to_keep = (l[:, 0] < math.log(args.proba_not_understands)) & ( l[:, 1] > math.log(args.proba_understands) ) + if args.dirty_debug: + # warnings.warn("DEBUG", RuntimeWarning) + to_keep = torch.rand(to_keep.size(), device=to_keep.device) < 0.5 -def extract_valid_quizzes_and_logprobas(recorded): - validated_quizzes, validated_logprobas = [], [] - for quizzes, token_logprobas in recorded: - validated_indices = compute_valid_quizzes(token_logprobas) - validated_quizzes.append(quizzes[validated_indices]) - validated_logprobas.append(token_logprobas[validated_indices]) - - if len(validated_quizzes) > 0: - return torch.cat(validated_quizzes, dim=0), torch.cat( - validated_logprobas, dim=0 - ) - else: - return None, None + return quizzes[to_keep] ###################################################################### -def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100): +def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100): nb_to_create = nb_for_train + nb_for_test - - recorded_quizzes_logprobas = [] - + nb_to_generate_per_iteration = nb_to_create nb_validated = 0 - start_time = time.perf_counter() + recorded = [] - nb_to_generate_per_iteration = nb_to_create + start_time = time.perf_counter() while nb_validated < nb_to_create: model_for_generation = models[torch.randint(len(models), (1,))] @@ -425,19 +407,11 @@ def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100): temperature=args.generation_temperature, ) - c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)] + c_quizzes = keep_good_quizzes(models, c_quizzes) - if c_quizzes.size(0) > 0: - token_logproba = quiz_machine.solution_token_logprobas(models, c_quizzes) - recorded_quizzes_logprobas.append((c_quizzes, token_logproba)) + nb_validated += c_quizzes.size(0) - ( - validated_quizzes, - validated_logprobas, - ) = extract_valid_quizzes_and_logprobas(recorded_quizzes_logprobas) - - if validated_quizzes is not None: - nb_validated = validated_quizzes.size(0) + recorded.append(c_quizzes) duration = time.perf_counter() - start_time @@ -454,6 +428,9 @@ def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100): f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create} (finishes {e})" ) + validated_quizzes = torch.cat(recorded, dim=0) + + ###################################################################### # store the new c_quizzes which have been validated v_train = validated_quizzes[:nb_for_train] @@ -465,20 +442,12 @@ def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100): quiz_machine.store_c_quizzes(quiz_machine.reverse_time(v_test), for_train=False) ###################################################################### - # save images with their logprobas + # save images vq = validated_quizzes[:128] - vl = validated_logprobas[:128] if vq.size(0) > 0: prefix = f"culture_c_quiz_{n_epoch:04d}" - filename = os.path.join(args.result_dir, prefix + "_logp.pth") - torch.save(vl, filename) - # with open(file_name, "w") as logp_file: - # for l in vl: - # s = " ".join([str(x.item()) for x in l]) - # logp_file.write(s + "\n") - quiz_machine.save_quiz_illustrations( args.result_dir, prefix, vq, show_part_to_predict=False ) @@ -630,7 +599,7 @@ for n_epoch in range(args.nb_epochs): # re-compute the test errors if min([m.main_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes: - create_c_quizzes( + record_new_c_quizzes( models, quiz_machine, nb_for_train=args.nb_new_c_quizzes_for_train,