From 207a4072f6e421bb746f14d5e3a2065ad45ce13c Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 18 Sep 2024 17:58:38 +0200 Subject: [PATCH] Update. --- main.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/main.py b/main.py index a7f9c9e..c4ecc49 100755 --- a/main.py +++ b/main.py @@ -497,7 +497,7 @@ def ae_generate(model, nb, local_device=main_device): all_changed = torch.full((all_input.size(0),), True, device=all_input.device) for it in range(args.diffusion_nb_iterations): - log_string(f"nb_changed {all_changed.long().sum().item()}") + # log_string(f"nb_changed {all_changed.long().sum().item()}") if not all_changed.any(): break @@ -892,9 +892,6 @@ if args.quizzes is not None: c_quizzes = None -time_c_quizzes = 0 -time_train = 0 - ###################################################################### @@ -980,36 +977,33 @@ for n_epoch in range(current_epoch, args.nb_epochs): nb_gpus = len(gpus) nb_c_quizzes_to_generate = (args.nb_c_quizzes + nb_gpus - 1) // nb_gpus - (c_quizzes,) = multithread_execution( + (new_c_quizzes,) = multithread_execution( generate_c_quizzes, [(models, nb_c_quizzes_to_generate, gpu) for gpu in gpus], ) save_quiz_image( models, - c_quizzes[:256], + new_c_quizzes[:256], f"culture_c_quiz_{n_epoch:04d}.png", solvable_only=False, ) save_quiz_image( models, - c_quizzes[:256], + new_c_quizzes[:256], f"culture_c_quiz_{n_epoch:04d}_solvable.png", solvable_only=True, ) - u = c_quizzes.reshape(c_quizzes.size(0), 4, -1)[:, :, 1:] - i = (u[:, 2] != u[:, 3]).long().sum(dim=1).sort(descending=True).indices + log_string(f"generated_c_quizzes {new_c_quizzes.size()=}") - save_quiz_image( - models, - c_quizzes[i][:256], - f"culture_c_quiz_{n_epoch:04d}_solvable_high_delta.png", - solvable_only=True, + c_quizzes = ( + new_c_quizzes + if c_quizzes is None + else torch.cat([c_quizzes, new_c_quizzes]) ) - - log_string(f"generated_c_quizzes {c_quizzes.size()=}") + c_quizzes = c_quizzes[-args.nb_train_samples :] for model in models: model.test_accuracy = 0 -- 2.39.5