From: François Fleuret
Date: Sat, 31 Aug 2024 21:53:33 +0000 (+0200)
Subject: Update.
X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=56409652dc770ab2a63c5377ca297797daf684e4;p=culture.git

Update.
---

diff --git a/main.py b/main.py
index 879d9fd..34bf920 100755
--- a/main.py
+++ b/main.py
@@ -964,11 +964,14 @@ def ae_batches(
     nb,
     data_structures,
     local_device,
+    c_quizzes=None,
     desc=None,
     batch_size=args.batch_size,
 ):
+    c_quiz_bags = [] if c_quizzes is None else [c_quizzes.to("cpu")]
+
     full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input(
-        nb, data_structures=data_structures
+        nb, c_quiz_bags, data_structures=data_structures
     )
 
     src = zip(
@@ -1237,7 +1240,9 @@ def run_ae_test(model, quiz_machine, n_epoch, local_device=main_device):
 ######################################################################
 
 
-def one_ae_epoch(model, other_models, quiz_machine, n_epoch, local_device=main_device):
+def one_ae_epoch(
+    model, other_models, quiz_machine, n_epoch, c_quizzes, local_device=main_device
+):
     model.train().to(local_device)
 
     nb_train_samples, acc_train_loss = 0, 0.0
@@ -1247,6 +1252,7 @@ def one_ae_epoch(model, other_models, quiz_machine, n_epoch, local_device=main_d
         args.nb_train_samples,
         data_structures,
         local_device,
+        c_quizzes,
         "training",
     ):
         input = input.to(local_device)
@@ -1325,9 +1331,9 @@ def c_quiz_criterion_some(probas):
 def generate_ae_c_quizzes(models, local_device=main_device):
     criteria = [
         c_quiz_criterion_one_good_one_bad,
-        c_quiz_criterion_diff,
-        c_quiz_criterion_two_certains,
-        c_quiz_criterion_some,
+        # c_quiz_criterion_diff,
+        # c_quiz_criterion_two_certains,
+        # c_quiz_criterion_some,
     ]
 
     for m in models:
@@ -1343,9 +1349,10 @@ def generate_ae_c_quizzes(models, local_device=main_device):
         quizzes=template, quad_order=quad_order, quad_mask=(1, 1, 1, 1)
     )
 
-    duration_max = 3600
+    duration_max = 4 * 3600
 
-    wanted_nb = 512
+    wanted_nb = 10000
+    nb_to_save = 128
 
     with torch.autograd.no_grad():
         records = [[] for _ in criteria]
@@ -1386,7 +1393,7 @@
             )
 
         for n, u in enumerate(records):
-            quizzes = torch.cat(u, dim=0)[:wanted_nb]
+            quizzes = torch.cat(u, dim=0)[:nb_to_save]
             filename = f"culture_c_{n_epoch:04d}_{n:02d}.png"
 
             # result, predicted_parts, correct_parts = bag_to_tensors(record)
@@ -1405,11 +1412,14 @@
                 # predicted_parts=predicted_parts,
                 # correct_parts=correct_parts,
                 comments=comments,
-                nrow=8,
             )
 
             log_string(f"wrote {filename}")
 
+    a = [torch.cat(u, dim=0) for u in records]
+
+    return torch.cat(a, dim=0).unique(dim=0)
+
 
 ######################################################################
 
@@ -1453,6 +1463,10 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
 
 ######################################################################
 
+last_n_epoch_c_quizzes = 0
+
+c_quizzes = None
+
 for n_epoch in range(current_epoch, args.nb_epochs):
     start_time = time.perf_counter()
 
@@ -1476,8 +1490,17 @@
     # one_ae_epoch(models[0], models, quiz_machine, n_epoch, main_device)
     # exit(0)
 
-    if min([m.test_accuracy for m in models]) > args.accuracy_to_make_c_quizzes:
-        generate_ae_c_quizzes(models, local_device=main_device)
+    if (
+        min([m.test_accuracy for m in models]) > args.accuracy_to_make_c_quizzes
+        and n_epoch >= last_n_epoch_c_quizzes + 10
+    ):
+        last_n_epoch_c_quizzes = n_epoch
+        c_quizzes = generate_ae_c_quizzes(models, local_device=main_device)
+
+    if c_quizzes is None:
+        log_string("no_c_quiz")
+    else:
+        log_string(f"nb_c_quizzes {c_quizzes.size(0)}")
 
     ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
     weakest_models = ranked_models[: len(gpus)]
@@ -1492,7 +1515,7 @@
         t = threading.Thread(
             target=one_ae_epoch,
             daemon=True,
-            args=(model, models, quiz_machine, n_epoch, gpu),
+            args=(model, models, quiz_machine, n_epoch, c_quizzes, gpu),
        )
 
         threads.append(t)
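
The scheduling logic this commit adds around generate_ae_c_quizzes can be summarized as: regenerate c-quizzes only when every model's test accuracy exceeds the threshold and at least 10 epochs have elapsed since the last generation, then hand the resulting tensor to each training thread. The following is a minimal, self-contained sketch of that gating, not code from the repository: Model, fake_generate and the constants are hypothetical stand-ins; only the condition and the logging mirror the diff.

import torch


class Model:
    # Stand-in for the real models; only test_accuracy matters here.
    def __init__(self, test_accuracy):
        self.test_accuracy = test_accuracy


def fake_generate(nb=16, width=8):
    # Stand-in for generate_ae_c_quizzes: a deduplicated batch of token rows.
    return torch.randint(0, 4, (nb, width)).unique(dim=0)


accuracy_to_make_c_quizzes = 0.95
models = [Model(0.97), Model(0.96)]

last_n_epoch_c_quizzes = 0
c_quizzes = None

for n_epoch in range(0, 40):
    # Generate only when all models are accurate enough and the last
    # generation is at least 10 epochs old.
    if (
        min(m.test_accuracy for m in models) > accuracy_to_make_c_quizzes
        and n_epoch >= last_n_epoch_c_quizzes + 10
    ):
        last_n_epoch_c_quizzes = n_epoch
        c_quizzes = fake_generate()

    if c_quizzes is None:
        print(f"epoch {n_epoch}: no_c_quiz")
    else:
        print(f"epoch {n_epoch}: nb_c_quizzes {c_quizzes.size(0)}")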