From cbf1356e55cea67f49b5cdce35033b3df81ed3f9 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 23 Aug 2024 22:52:01 +0200 Subject: [PATCH] Update. --- main.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 2f5e0dc..7ce9b03 100755 --- a/main.py +++ b/main.py @@ -872,10 +872,14 @@ def degrade_input_inplace(input, mask_generate, pure_noise=False): mask_diffusion_noise = mask_diffusion_noise.long() input[...] = ( - 1 - mask_generate - ) * input + mask_generate * mask_diffusion_noise * torch.randint( - quiz_machine.problem.nb_colors, input.size(), device=input.device + mask_generate + * mask_diffusion_noise + * torch.randint( + quiz_machine.problem.nb_colors, input.size(), device=input.device + ) + + (1 - mask_generate * mask_diffusion_noise) * input ) + else: model.eval() for it in range(torch.randint(5, (1,)).item()): @@ -944,6 +948,20 @@ def test_ae(local_device=main_device): torch.cat([input[:, :, None], mask_generate[:, :, None]], dim=2) ) ).x + + # for filename, quizzes in [ + # ("targets.png", targets), + # ("input.png", input), + # ("mask_generate.png", mask_generate), + # ("mask_loss.png", mask_loss), + # ]: + # quiz_machine.problem.save_quizzes_as_image( + # args.result_dir, + # filename, + # quizzes=quizzes, + # ) + # time.sleep(10000) + loss_per_token = F.cross_entropy( output.transpose(1, 2), targets, reduction="none" ) -- 2.39.5