From: François Fleuret Date: Fri, 23 Aug 2024 05:16:55 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=7f02dc7ed22ea31a434f2f959bffa3443e6f459f;p=culture.git Update. --- diff --git a/main.py b/main.py index a65d893..1999bac 100755 --- a/main.py +++ b/main.py @@ -905,13 +905,6 @@ def test_ae(local_device=main_device): quiz_machine.problem.nb_colors, input.size(), device=input.device ) - L = input.size(1) // 4 - - input[:, 0 * L] = targets[:, 0 * L] - input[:, 1 * L] = targets[:, 1 * L] - input[:, 2 * L] = targets[:, 2 * L] - input[:, 3 * L] = targets[:, 3 * L] - output = model(mygpt.BracketedSequence(input)).x loss = F.cross_entropy(output.transpose(1, 2), targets) acc_train_loss += loss.item() * input.size(0) @@ -955,13 +948,6 @@ def test_ae(local_device=main_device): quiz_machine.problem.nb_colors, input.size(), device=input.device ) - L = input.size(1) // 4 - - input[:, 0 * L] = targets[:, 0 * L] - input[:, 1 * L] = targets[:, 1 * L] - input[:, 2 * L] = targets[:, 2 * L] - input[:, 3 * L] = targets[:, 3 * L] - output = model(mygpt.BracketedSequence(input)).x loss = F.cross_entropy(output.transpose(1, 2), targets) acc_test_loss += loss.item() * input.size(0) @@ -975,8 +961,9 @@ def test_ae(local_device=main_device): targets = input + input = (1 - mask_generate) * input # PARANOIAAAAAAAAAA + pred_result = None - frozzen = None mask_noise = (mask_generate != 0) & ( torch.rand(mask_generate.size(), device=mask_generate.device) @@ -989,13 +976,6 @@ def test_ae(local_device=main_device): quiz_machine.problem.nb_colors, input.size(), device=input.device ) - L = input.size(1) // 4 - - result[:, 0 * L] = input[:, 0 * L] - result[:, 1 * L] = input[:, 1 * L] - result[:, 2 * L] = input[:, 2 * L] - result[:, 3 * L] = input[:, 3 * L] - i = torch.full((result.size(0),), True, device=result.device) nb_it = 0 @@ -1004,11 +984,9 @@ def test_ae(local_device=main_device): logits = model(mygpt.BracketedSequence(result)).x dist = torch.distributions.categorical.Categorical(logits=logits) pred_result = result.clone() - result[i] = dist.sample()[i] - result[:, 0 * L] = input[:, 0 * L] - result[:, 1 * L] = input[:, 1 * L] - result[:, 2 * L] = input[:, 2 * L] - result[:, 3 * L] = input[:, 3 * L] + result[i] = (1 - mask_generate) * input + ( + mask_generate * dist.sample()[i] + ) changed = (pred_result == result).long().min(dim=1).values == 0 i = i & changed nb_it += 1