From: François Fleuret Date: Sat, 31 Aug 2024 20:00:33 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=4b8677664ecb16aa86f7acbdfe8bc71259c96bbb;p=culture.git Update. --- diff --git a/main.py b/main.py index ab625cc..879d9fd 100755 --- a/main.py +++ b/main.py @@ -1316,11 +1316,18 @@ def c_quiz_criterion_two_certains(probas): return ((probas >= 0.99).long().sum(dim=1) >= 2) & (probas.min(dim=1).values <= 0.5) +def c_quiz_criterion_some(probas): + return ((probas >= 0.8).long().sum(dim=1) >= 1) & ( + (probas <= 0.2).long().sum(dim=1) >= 1 + ) + + def generate_ae_c_quizzes(models, local_device=main_device): criteria = [ c_quiz_criterion_one_good_one_bad, c_quiz_criterion_diff, c_quiz_criterion_two_certains, + c_quiz_criterion_some, ] for m in models: @@ -1336,7 +1343,9 @@ def generate_ae_c_quizzes(models, local_device=main_device): quizzes=template, quad_order=quad_order, quad_mask=(1, 1, 1, 1) ) - duration_max = 600 # 3 * 3600 + duration_max = 3600 + + wanted_nb = 512 with torch.autograd.no_grad(): records = [[] for _ in criteria] @@ -1345,7 +1354,7 @@ def generate_ae_c_quizzes(models, local_device=main_device): while ( time.perf_counter() < start_time + duration_max - and min([bag_len(bag) for bag in records]) < 128 + and min([bag_len(bag) for bag in records]) < wanted_nb ): bl = [bag_len(bag) for bag in records] log_string(f"bag_len {bl}") @@ -1353,39 +1362,53 @@ def generate_ae_c_quizzes(models, local_device=main_device): model = models[torch.randint(len(models), (1,)).item()] result = ae_generate(model, template, mask_generate, noise_proba) - probas = torch.cat( - [model_ae_proba_solutions(model, result)[:, None] for model in models], - dim=1, - ) + to_keep = quiz_machine.problem.trivial(result) == False + result = result[to_keep] - for c, r in zip(criteria, records): - q = result[c(probas)] - if q.size(0) > 0: - r.append(q) + if result.size(0) > 0: + probas = torch.cat( + [ + model_ae_proba_solutions(model, result)[:, None] + for model in models + ], + dim=1, + ) - for n, u in enumerate(records): - quizzes = torch.cat(u, dim=0)[:128] - filename = f"culture_{n_epoch:04d}_{n:02d}.png" + for c, r in zip(criteria, records): + q = result[c(probas)] + if q.size(0) > 0: + r.append(q) - # result, predicted_parts, correct_parts = bag_to_tensors(record) + duration = time.perf_counter() - start_time - # l = [model_ae_proba_solutions(model, result) for model in models] - # probas = torch.cat([x[:, None] for x in l], dim=1) - # comments = [] + log_string( + f"generate_c_quizz_generation_speed {int(3600 * wanted_nb / duration)}/h" + ) - # for l in probas: - # comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l])) + for n, u in enumerate(records): + quizzes = torch.cat(u, dim=0)[:wanted_nb] + filename = f"culture_c_{n_epoch:04d}_{n:02d}.png" - quiz_machine.problem.save_quizzes_as_image( - args.result_dir, - filename, - quizzes=result, - # predicted_parts=predicted_parts, - # correct_parts=correct_parts, - # comments=comments, - ) + # result, predicted_parts, correct_parts = bag_to_tensors(record) - log_string(f"wrote {filename}") + l = [model_ae_proba_solutions(model, quizzes) for model in models] + probas = torch.cat([x[:, None] for x in l], dim=1) + comments = [] + + for l in probas: + comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l])) + + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, + filename, + quizzes=quizzes, + # predicted_parts=predicted_parts, + # correct_parts=correct_parts, + comments=comments, + nrow=8, + ) + + log_string(f"wrote {filename}") ######################################################################