From feefa5ca1a794b55307fd4c7668a66bcd2b447e3 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 23 Aug 2024 08:56:35 +0200 Subject: [PATCH] Update. --- main.py | 90 +++++++++++++++++++++++++++++++++++---------------------- 1 file changed, 55 insertions(+), 35 deletions(-) diff --git a/main.py b/main.py index e7dd337..289bae4 100755 --- a/main.py +++ b/main.py @@ -894,18 +894,28 @@ def test_ae(local_device=main_device): targets = input - mask_diffusion_noise = (mask_generate == 1) & ( - torch.rand(mask_generate.size(), device=mask_generate.device) - <= torch.rand((mask_generate.size(0), 1), device=mask_generate.device) - ) + # mask_diffusion_noise = (mask_generate == 1) & ( + # torch.rand(mask_generate.size(), device=mask_generate.device) + # <= torch.rand((mask_generate.size(0), 1), device=mask_generate.device) + # ) - mask_diffusion_noise = mask_diffusion_noise.long() + # mask_diffusion_noise = mask_diffusion_noise.long() - input = ( - 1 - mask_diffusion_noise - ) * input + mask_diffusion_noise * torch.randint( - quiz_machine.problem.nb_colors, input.size(), device=input.device - ) + # input = ( + # 1 - mask_diffusion_noise + # ) * input + mask_diffusion_noise * torch.randint( + # quiz_machine.problem.nb_colors, input.size(), device=input.device + # ) + + # ------------------------------ + input = (1 - mask_generate) * input # PARANOIAAAAAAAAAA + model.eval() + for it in range(torch.randint(5, (1,)).item()): + logits = model(mygpt.BracketedSequence(input)).x + dist = torch.distributions.categorical.Categorical(logits=logits) + input = (1 - mask_generate) * input + mask_generate * dist.sample() + model.train() + # ----------------------------- output = model(mygpt.BracketedSequence(input)).x loss = F.cross_entropy(output.transpose(1, 2), targets) @@ -937,20 +947,29 @@ def test_ae(local_device=main_device): ): targets = input - mask_diffusion_noise = (mask_generate == 1) & ( - torch.rand(mask_generate.size(), device=mask_generate.device) - <= torch.rand( - (mask_generate.size(0), 1), device=mask_generate.device - ) - ) + # mask_diffusion_noise = (mask_generate == 1) & ( + # torch.rand(mask_generate.size(), device=mask_generate.device) + # <= torch.rand( + # (mask_generate.size(0), 1), device=mask_generate.device + # ) + # ) - mask_diffusion_noise = mask_diffusion_noise.long() + # mask_diffusion_noise = mask_diffusion_noise.long() - input = ( - 1 - mask_diffusion_noise - ) * input + mask_diffusion_noise * torch.randint( - quiz_machine.problem.nb_colors, input.size(), device=input.device - ) + # input = ( + # 1 - mask_diffusion_noise + # ) * input + mask_diffusion_noise * torch.randint( + # quiz_machine.problem.nb_colors, input.size(), device=input.device + # ) + + # ------------------------------ + input = (1 - mask_generate) * input # PARANOIAAAAAAAAAA + + for it in range(torch.randint(5, (1,)).item()): + logits = model(mygpt.BracketedSequence(input)).x + dist = torch.distributions.categorical.Categorical(logits=logits) + input = (1 - mask_generate) * input + mask_generate * dist.sample() + # ----------------------------- output = model(mygpt.BracketedSequence(input)).x loss = F.cross_entropy(output.transpose(1, 2), targets) @@ -973,26 +992,27 @@ def test_ae(local_device=main_device): input = (1 - mask_generate) * input # PARANOIAAAAAAAAAA - result = (1 - mask_generate) * input + mask_generate * torch.randint( - quiz_machine.problem.nb_colors, input.size(), device=input.device - ) + result = (1 - mask_generate) * input + + # + mask_generate * torch.randint( + # quiz_machine.problem.nb_colors, input.size(), device=input.device + # ) not_converged = torch.full( (result.size(0),), True, device=result.device ) - nb_it = 0 - - while True: - logits = model(mygpt.BracketedSequence(result)).x - dist = torch.distributions.categorical.Categorical(logits=logits) + for it in range(100): pred_result = result.clone() - update = (1 - mask_generate) * input + mask_generate * dist.sample() - result[not_converged] = update[not_converged] + logits = model(mygpt.BracketedSequence(result[not_converged])).x + dist = torch.distributions.categorical.Categorical(logits=logits) + update = (1 - mask_generate[not_converged]) * input[ + not_converged + ] + mask_generate[not_converged] * dist.sample() + result[not_converged] = update not_converged = (pred_result != result).max(dim=1).values - nb_it += 1 - print("DEBUG", nb_it, not_converged.long().sum().item()) - if not not_converged.any() or nb_it > 100: + if not not_converged.any(): + log_string(f"diffusion_converged {it=}") break correct = (result == targets).min(dim=1).values.long() -- 2.39.5