From ab88d3787fe47e4c52c6a3c78bb7354294fd6752 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 17 Sep 2024 19:51:36 +0200 Subject: [PATCH] Update. --- main.py | 93 +++++++++++++++++++++++++++++++++++-------------- quiz_machine.py | 2 ++ 2 files changed, 68 insertions(+), 27 deletions(-) diff --git a/main.py b/main.py index 51e0fa2..9525bdd 100755 --- a/main.py +++ b/main.py @@ -327,10 +327,6 @@ quiz_machine = quiz_machine.QuizMachine( ) -def mu_T_sampler(shape, device="cpu"): - return torch.randint(quiz_machine.problem.nb_colors, shape, device=device) - - diffuser = diffusion.Diffuser( mu_T_sampler, args.diffusion_nb_iterations, args.diffusion_proba_corruption ) @@ -397,22 +393,27 @@ def masked_cross_entropy(output, targets, masks): ###################################################################### + +def add_hints(masks, fraction_with_hints): + if fraction_with_hints > 0: + h = torch.rand(masks.size(), device=masks.device) * masks + mask_hints = h.sort(dim=1, descending=True).values < args.nb_hints + v = torch.rand(masks.size(0), device=masks.device)[:, None] + mask_hints = mask_hints * (v < fraction_with_hints).long() + return (1 - mask_hints) * masks + else: + return masks + + # IMT for input / masks / target -def IMT_batch_prediction(input, proba_hints=0.0): +def batch_prediction_imt(input, fraction_with_hints=0.0): nb = input.size(0) masks = input.new_zeros(input.size()) u = F.one_hot(torch.randint(4, (nb,), device=masks.device), num_classes=4) masks.view(nb, 4, -1)[:, :, 1:] = u[:, :, None] - - if proba_hints > 0: - h = torch.rand(input.size(), device=input.device) * masks - mask_hints = h.sort(dim=1, descending=True).values < args.nb_hints - v = torch.rand(nb, device=input.device)[:, None] - mask_hints = mask_hints * (v < proba_hints).long() - masks = (1 - mask_hints) * masks - + masks = add_hints(masks, fraction_with_hints) # noise = quiz_machine.problem.pure_noise(nb, input.device) targets = input input = (1 - masks) * targets # + masks * noise @@ -444,10 +445,32 @@ def predict(model, imt_set, local_device=main_device): return torch.cat(record) +def predict_full(model, input, fraction_with_hints=0.0, local_device=main_device): + boy_that_s_ugly = input.view(input.size(0), 4, -1)[:, :, 0].clone() + input = input[:, None, :].expand(-1, 4, -1).reshape(-1, input.size(1)) + nb = input.size(0) + masks = input.new_zeros(input.size()) + u = F.one_hot(torch.arange(nb, device=masks.device) % 4, num_classes=4) + masks.view(nb, 4, -1)[:, :, 1:] = u[:, :, None] + masks_with_hints = add_hints(masks, fraction_with_hints) + targets = input + input = (1 - masks_with_hints) * targets + imt_set = torch.cat( + [input[:, None], masks_with_hints[:, None], targets[:, None]], dim=1 + ) + + result = predict(model, imt_set, local_device=local_device) + result = (result * masks).reshape(-1, 4, result.size(1)).sum(dim=1) + + result.view(result.size(0), 4, -1)[:, :, 0] = boy_that_s_ugly + + return result + + ###################################################################### -def IMT_batch_generation(input): +def batch_generation_imt(input): nb = input.size(0) probs_iterations = 0.1 ** torch.linspace( 0, 1, args.diffusion_nb_iterations, device=input.device @@ -516,16 +539,6 @@ def generate(model, nb, local_device=main_device): def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True): - if train: - label = "train" - model.train().to(local_device) - optimizer_to(model.optimizer, local_device) - else: - label = "test" - model.eval().to(local_device) - - nb_samples, acc_loss = 0, 0.0 - quizzes = quiz_machine.quiz_set( args.nb_train_samples if train else args.nb_test_samples, c_quizzes, @@ -535,11 +548,24 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True): q1, q2 = quizzes.to(local_device).chunk(2) imt_set = torch.cat( - [IMT_batch_prediction(q1, proba_hints=0.5), IMT_batch_generation(q2)] + [ + batch_prediction_imt(q1, fraction_with_hints=0.5), + batch_generation_imt(q2), + ] ) imt_set = imt_set[torch.randperm(imt_set.size(0), device=imt_set.device)] + if train: + label = "train" + model.train().to(local_device) + optimizer_to(model.optimizer, local_device) + else: + label = "test" + model.eval().to(local_device) + + nb_samples, acc_loss = 0, 0.0 + for imt in tqdm.tqdm( imt_set.split(args.physical_batch_size), dynamic_ncols=True, @@ -574,10 +600,23 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device): one_epoch(model, n_epoch, c_quizzes, local_device=local_device, train=True) one_epoch(model, n_epoch, c_quizzes, local_device=local_device, train=False) + #!!!!!!!!!!!!!!!!!!!!!!!!! + quizzes = quiz_machine.quiz_set(25, c_quizzes, args.c_quiz_multiplier).to( + local_device + ) + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, f"test_{n_epoch}_{model.id}.png", quizzes=quizzes + ) + result = predict_full(model, quizzes, local_device=local_device) + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, f"test_{n_epoch}_{model.id}_predict_full.png", quizzes=result + ) + #!!!!!!!!!!!!!!!!!!!!!!!!! + # predict quizzes = quiz_machine.quiz_set(150, c_quizzes, args.c_quiz_multiplier) - imt_set = IMT_batch_prediction(quizzes.to(local_device)) + imt_set = batch_prediction_imt(quizzes.to(local_device)) result = predict(model, imt_set, local_device=local_device).to("cpu") masks = imt_set[:, 1].to("cpu") @@ -638,7 +677,7 @@ for i in range(args.nb_models): ###################################################################### -def quiz_validation( +def quiz_validation_( models, c_quizzes, local_device, diff --git a/quiz_machine.py b/quiz_machine.py index 781c1cf..dfedbf5 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -222,6 +222,8 @@ class QuizMachine: i = torch.randperm(quizzes.size(0), device=quizzes.device) quizzes = quizzes[i].contiguous() + quizzes = quizzes.view(quizzes.size(0), 4, -1)[:, :, 1:].contiguous() + return quizzes ###################################################################### -- 2.39.5