From: François Fleuret Date: Fri, 23 Aug 2024 05:29:48 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=f32e85048904ff54d731054f8c0c3d8e47b9017d;p=culture.git Update. --- diff --git a/main.py b/main.py index 0fe33f6..2f867db 100755 --- a/main.py +++ b/main.py @@ -969,12 +969,11 @@ def test_ae(local_device=main_device): targets = input input = (1 - mask_generate) * input # PARANOIAAAAAAAAAA - pred_result = None result = (1 - mask_generate) * input + mask_generate * torch.randint( quiz_machine.problem.nb_colors, input.size(), device=input.device ) - i = torch.full((result.size(0),), True, device=result.device) + not_converged = torch.full((result.size(0),), True, device=result.device) nb_it = 0 @@ -982,11 +981,10 @@ def test_ae(local_device=main_device): logits = model(mygpt.BracketedSequence(result)).x dist = torch.distributions.categorical.Categorical(logits=logits) pred_result = result.clone() - result[i] = (1 - mask_generate[i]) * input + ( - mask_generate * dist.sample()[i] - ) - changed = (pred_result == result).long().min(dim=1).values == 0 - i = i & changed + result[not_converged] = ( + (1 - mask_generate) * input + mask_generate * dist.sample() + )[not_converged] + not_converged = (pred_result == result).long().min(dim=1).values == 0 nb_it += 1 print("DEBUG", nb_it, i.long().sum().item()) if not i.any() or nb_it > 100: