From f32e85048904ff54d731054f8c0c3d8e47b9017d Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 23 Aug 2024 07:29:48 +0200 Subject: [PATCH] Update. --- main.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/main.py b/main.py index 0fe33f6..2f867db 100755 --- a/main.py +++ b/main.py @@ -969,12 +969,11 @@ def test_ae(local_device=main_device): targets = input input = (1 - mask_generate) * input # PARANOIAAAAAAAAAA - pred_result = None result = (1 - mask_generate) * input + mask_generate * torch.randint( quiz_machine.problem.nb_colors, input.size(), device=input.device ) - i = torch.full((result.size(0),), True, device=result.device) + not_converged = torch.full((result.size(0),), True, device=result.device) nb_it = 0 @@ -982,11 +981,10 @@ def test_ae(local_device=main_device): logits = model(mygpt.BracketedSequence(result)).x dist = torch.distributions.categorical.Categorical(logits=logits) pred_result = result.clone() - result[i] = (1 - mask_generate[i]) * input + ( - mask_generate * dist.sample()[i] - ) - changed = (pred_result == result).long().min(dim=1).values == 0 - i = i & changed + result[not_converged] = ( + (1 - mask_generate) * input + mask_generate * dist.sample() + )[not_converged] + not_converged = (pred_result == result).long().min(dim=1).values == 0 nb_it += 1 print("DEBUG", nb_it, i.long().sum().item()) if not i.any() or nb_it > 100: -- 2.39.5