From: François Fleuret Date: Sun, 15 Sep 2024 12:26:49 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=f068c8ba65880d4829198eba057b145f1e2c4cf0;p=culture.git Update. --- diff --git a/diffusion.py b/diffusion.py index 8c6e08d..2dc5861 100755 --- a/diffusion.py +++ b/diffusion.py @@ -52,13 +52,40 @@ class Diffuser: ###################################################################### + def make_mask_hints(self, mask_generate, nb_hints): + if nb_hints == 0: + mask_hints = None + else: + u = ( + torch.rand(mask_generate.size(), device=mask_generate.device) + * mask_generate + ) + mask_hints = ( + u > u.sort(dim=1, descending=True).values[:, nb_hints, None] + ).long() + + return mask_hints + # This function gets a clean target x_0, and a mask indicating which # part to generate (conditionnaly to the others), and returns the # logits starting from a x_t|X_0=x_0 picked at random with t random def logits_hat_x_0_from_random_iteration( - self, model, x_0, mask_generate, prompt_noise=0.0 + self, model, x_0, mask_generate, nb_hints=0, prompt_noise=0.0 ): + noise = self.mu_T_sampler(x_0.size(), device=x_0.device) + + single_iteration = ( + mask_generate.sum(dim=1) < mask_generate.size(1) // 2 + ).long()[:, None] + + mask_hints = self.make_mask_hints(mask_generate, nb_hints) + + if mask_hints is None: + mask_start = mask_generate + else: + mask_start = mask_generate * (1 - mask_hints) + # We favor iterations near the clean signal probs_iterations = 0.1 ** torch.linspace( @@ -71,7 +98,9 @@ class Diffuser: t = dist.sample() + 1 - x_t = self.sample_x_t_given_x_0(x_0, t) + x_t = single_iteration * noise + ( + 1 - single_iteration + ) * self.sample_x_t_given_x_0(x_0, t) # Only the part to generate is degraded, the rest is a perfect # noise-free conditionning @@ -99,13 +128,15 @@ class Diffuser: ###################################################################### - def ae_generate(self, model, x_0, mask_generate, mask_hints=None): + def generate(self, model, x_0, mask_generate, nb_hints=0): noise = self.mu_T_sampler(x_0.size(), device=x_0.device) single_iteration = ( mask_generate.sum(dim=1) < mask_generate.size(1) // 2 ).long()[:, None] + mask_hints = self.make_mask_hints(mask_generate, nb_hints) + if mask_hints is None: mask_start = mask_generate else: diff --git a/main.py b/main.py index 1461ab1..d508c97 100755 --- a/main.py +++ b/main.py @@ -79,6 +79,12 @@ parser.add_argument("--learning_rate", type=float, default=5e-4) parser.add_argument("--reboot", action="store_true", default=False) +parser.add_argument("--nb_have_to_be_correct", type=int, default=3) + +parser.add_argument("--nb_have_to_be_wrong", type=int, default=1) + +parser.add_argument("--nb_mistakes_to_be_wrong", type=int, default=5) + # ---------------------------------- parser.add_argument("--model", type=str, default="37M") @@ -388,7 +394,7 @@ data_structures = [ ###################################################################### -def model_ae_proba_solutions(model, input, log_probas=False, reduce=True): +def model_proba_solutions(model, input, log_probas=False, reduce=True): record = [] for x_0 in input.split(args.batch_size): @@ -422,7 +428,7 @@ def model_ae_proba_solutions(model, input, log_probas=False, reduce=True): ###################################################################### -def ae_batches( +def batches( quiz_machine, nb, data_structures, @@ -469,7 +475,7 @@ def NTC_masked_cross_entropy(output, targets, mask): ###################################################################### -def run_ae_test( +def run_test( model, quiz_machine, n_epoch, c_quizzes=None, local_device=main_device, prefix=None ): if prefix is None: @@ -484,7 +490,7 @@ def run_ae_test( nb_test_samples, acc_test_loss = 0, 0.0 - for x_0, mask_generate in ae_batches( + for x_0, mask_generate in batches( quiz_machine, args.nb_test_samples, data_structures, @@ -509,7 +515,7 @@ def run_ae_test( nb_correct, nb_total, record_d, record_nd = 0, 0, [], [] - for x_0, mask_generate in ae_batches( + for x_0, mask_generate in batches( quiz_machine, args.nb_test_samples, data_structures, @@ -517,9 +523,7 @@ def run_ae_test( c_quizzes=c_quizzes, desc="test", ): - result = diffuser.ae_generate( - model, (1 - mask_generate) * x_0, mask_generate - ) + result = diffuser.generate(model, (1 - mask_generate) * x_0, mask_generate) correct = (result == x_0).min(dim=1).values.long() predicted_parts = mask_generate.reshape(mask_generate.size(0), 4, -1)[ :, :, 1 @@ -561,7 +565,7 @@ def run_ae_test( ###################################################################### -def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_device): +def one_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_device): model.train().to(local_device) optimizer_to(model.optimizer, local_device) @@ -569,7 +573,7 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi # scaler = torch.amp.GradScaler("cuda") - for x_0, mask_generate in ae_batches( + for x_0, mask_generate in batches( quiz_machine, args.nb_train_samples, data_structures, @@ -611,12 +615,12 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}" ) - model.test_accuracy = run_ae_test( + model.test_accuracy = run_test( model, quiz_machine, n_epoch, c_quizzes=None, local_device=local_device ) if args.nb_test_alien_samples > 0: - run_ae_test( + run_test( model, alien_quiz_machine, n_epoch, @@ -662,12 +666,15 @@ def quiz_validation( models, c_quizzes, local_device, - nb_have_to_be_correct=3, - nb_have_to_be_wrong=1, - nb_mistakes_to_be_wrong=5, + nb_have_to_be_correct, + nb_have_to_be_wrong, + nb_mistakes_to_be_wrong, nb_hints=0, nb_runs=1, ): + ###################################################################### + # If too many with process per-batch + if c_quizzes.size(0) > args.inference_batch_size: record = [] for q in c_quizzes.split(args.inference_batch_size): @@ -684,9 +691,12 @@ def quiz_validation( ) ) - return (torch.cat([tk for tk, _ in record], dim=0)), ( - torch.cat([w for _, w in record], dim=0) - ) + r = [] + for k in range(len(record[0])): + r.append(torch.cat([x[k] for x in record], dim=0)) + + return tuple(r) + ###################################################################### record_wrong = [] nb_correct, nb_wrong = 0, 0 @@ -704,22 +714,11 @@ def quiz_validation( sub_correct, sub_wrong = False, True for _ in range(nb_runs): - if nb_hints == 0: - mask_hints = None - else: - u = ( - torch.rand(mask_generate.size(), device=mask_generate.device) - * mask_generate - ) - mask_hints = ( - u > u.sort(dim=1, descending=True).values[:, nb_hints, None] - ).long() - - result = ae_generate( + result = diffuser.generate( model=model, x_0=c_quizzes, mask_generate=mask_generate, - mask_hints=mask_hints, + nb_hints=nb_hints, ) nb_mistakes = (result != c_quizzes).long().sum(dim=1) @@ -746,7 +745,7 @@ def quiz_validation( ###################################################################### -def generate_ae_c_quizzes(models, nb, local_device=main_device): +def generate_c_quizzes(models, nb, local_device=main_device): # To be thread-safe we must make copies def copy_for_inference(model): @@ -776,7 +775,7 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device): quizzes=template, quad_order=quad_order, quad_mask=(1, 1, 1, 1) ) - c_quizzes = ae_generate(model, template, mask_generate) + c_quizzes = diffuser.generate(model, template, mask_generate) to_keep = quiz_machine.problem.trivial(c_quizzes) == False c_quizzes = c_quizzes[to_keep] @@ -786,6 +785,9 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device): models, c_quizzes, local_device, + nb_have_to_be_correct=args.nb_have_to_be_correct, + nb_have_to_be_wrong=args.nb_have_to_be_wrong, + nb_mistakes_to_be_wrong=args.nb_mistakes_to_be_wrong, nb_hints=args.nb_hints, nb_runs=args.nb_runs, ) @@ -839,21 +841,25 @@ def save_c_quizzes_with_scores(models, c_quizzes, filename, solvable_only=False) c_quizzes = c_quizzes.to(main_device) with torch.autograd.no_grad(): + to_keep, nb_correct, nb_wrong, record_wrong = quiz_validation( + models, + c_quizzes, + main_device, + nb_have_to_be_correct=args.nb_have_to_be_correct, + nb_have_to_be_wrong=0, + nb_mistakes_to_be_wrong=args.nb_mistakes_to_be_wrong, + nb_hints=0, + ) + if solvable_only: - to_keep, nb_correct, nb_wrong, record_wrong = quiz_validation( - models, - c_quizzes, - main_device, - nb_have_to_be_correct=2, - nb_have_to_be_wrong=0, - nb_hints=0, - ) c_quizzes = c_quizzes[to_keep] + nb_correct = nb_correct[to_keep] + nb_wrong = nb_wrong[to_keep] - comments = [] + comments = [] - for c, w in zip(nb_correct, nb_wrong): - comments.append("nb_correct {c} nb_wrong {w}") + for c, w in zip(nb_correct, nb_wrong): + comments.append(f"nb_correct {c} nb_wrong {w}") quiz_machine.problem.save_quizzes_as_image( args.result_dir, @@ -922,7 +928,7 @@ if args.quizzes is not None: mask_generate = quiz_machine.make_quiz_mask( quizzes=quizzes, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad ) - result = ae_generate( + result = generate( model, (1 - mask_generate) * quizzes, mask_generate, @@ -968,10 +974,9 @@ def multithread_execution(fun, arguments): records.append(fun(*args)) for args in arguments: - t = threading.Thread(target=threadable_fun, daemon=True, args=args) - # To get a different sequence between threads log_string(f"dummy_rand {torch.rand(1)}") + t = threading.Thread(target=threadable_fun, daemon=True, args=args) threads.append(t) t.start() @@ -1039,7 +1044,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): nb_c_quizzes_to_generate = (args.nb_c_quizzes + nb_gpus - 1) // nb_gpus c_quizzes, agreements = multithread_execution( - generate_ae_c_quizzes, + generate_c_quizzes, [(models, nb_c_quizzes_to_generate, gpu) for gpu in gpus], ) @@ -1057,7 +1062,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): solvable_only=True, ) - u = c_quizzes.reshape(c_quizzes.size(0), 4, -1)[:, 1:] + u = c_quizzes.reshape(c_quizzes.size(0), 4, -1)[:, :, 1:] i = (u[:, 2] != u[:, 3]).long().sum(dim=1).sort(descending=True).indices save_c_quizzes_with_scores( @@ -1085,7 +1090,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): # None if c_quizzes is None else c_quizzes[agreements[:, model.id]], multithread_execution( - one_ae_epoch, + one_epoch, [ (model, quiz_machine, n_epoch, c_quizzes, gpu) for model, gpu in zip(weakest_models, gpus)