From d9af6a35001df7f478a8a1cd46f0d0359b0b0065 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 23 Aug 2024 18:35:53 +0200 Subject: [PATCH] Update. --- main.py | 151 ++++++++++++++++---------------------------------------- 1 file changed, 42 insertions(+), 109 deletions(-) diff --git a/main.py b/main.py index c6d76ee..2f5e0dc 100755 --- a/main.py +++ b/main.py @@ -863,6 +863,32 @@ def ae_batches(quiz_machine, nb, data_structures, local_device, desc=None): ) +def degrade_input_inplace(input, mask_generate, pure_noise=False): + if pure_noise: + mask_diffusion_noise = torch.rand( + mask_generate.size(), device=mask_generate.device + ) <= torch.rand((mask_generate.size(0), 1), device=mask_generate.device) + + mask_diffusion_noise = mask_diffusion_noise.long() + + input[...] = ( + 1 - mask_generate + ) * input + mask_generate * mask_diffusion_noise * torch.randint( + quiz_machine.problem.nb_colors, input.size(), device=input.device + ) + else: + model.eval() + for it in range(torch.randint(5, (1,)).item()): + logits = model( + mygpt.BracketedSequence( + torch.cat([input[:, :, None], mask_generate[:, :, None]], dim=2) + ) + ).x + dist = torch.distributions.categorical.Categorical(logits=logits) + input[...] = (1 - mask_generate) * input + mask_generate * dist.sample() + model.train() + + def test_ae(local_device=main_device): model = MyAttentionAE( vocabulary_size=vocabulary_size, @@ -876,6 +902,7 @@ def test_ae(local_device=main_device): pure_noise = True + # quad_order, quad_generate, quad_noise, quad_loss data_structures = [ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)), (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0)), @@ -909,40 +936,18 @@ def test_ae(local_device=main_device): if nb_train_samples % args.batch_size == 0: model.optimizer.zero_grad() - targets = input - - input = (1 - mask_generate) * input # PARANOIAAAAAAAAAA - - if pure_noise: - mask_diffusion_noise = torch.rand( - mask_generate.size(), device=mask_generate.device - ) <= torch.rand((mask_generate.size(0), 1), device=mask_generate.device) - - mask_diffusion_noise = mask_diffusion_noise.long() - - input = input + mask_generate * mask_diffusion_noise * torch.randint( - quiz_machine.problem.nb_colors, input.size(), device=input.device - ) - else: - model.eval() - for it in range(torch.randint(5, (1,)).item()): - logits = model( - mygpt.BracketedSequence( - torch.cat( - [input[:, :, None], mask_generate[:, :, None]], dim=2 - ) - ) - ).x - dist = torch.distributions.categorical.Categorical(logits=logits) - input = (1 - mask_generate) * input + mask_generate * dist.sample() - model.train() + targets = input.clone() + degrade_input_inplace(input, mask_generate, pure_noise=pure_noise) output = model( mygpt.BracketedSequence( torch.cat([input[:, :, None], mask_generate[:, :, None]], dim=2) ) ).x - loss = F.cross_entropy(output.transpose(1, 2), targets) + loss_per_token = F.cross_entropy( + output.transpose(1, 2), targets, reduction="none" + ) + loss = (loss_per_token * mask_loss).mean() acc_train_loss += loss.item() * input.size(0) nb_train_samples += input.size(0) loss.backward() @@ -969,51 +974,17 @@ def test_ae(local_device=main_device): local_device, "test", ): - targets = input - - input = (1 - mask_generate) * input # PARANOIAAAAAAAAAA - - if pure_noise: - mask_diffusion_noise = torch.rand( - mask_generate.size(), device=mask_generate.device - ) <= torch.rand( - (mask_generate.size(0), 1), device=mask_generate.device - ) - - mask_diffusion_noise = mask_diffusion_noise.long() - - input = ( - input - + mask_generate - * mask_diffusion_noise - * torch.randint( - quiz_machine.problem.nb_colors, - input.size(), - device=input.device, - ) - ) - else: - for it in range(torch.randint(5, (1,)).item()): - logits = model( - mygpt.BracketedSequence( - torch.cat( - [input[:, None], mask_generate[:, None]], dim=1 - ) - ) - ).x - dist = torch.distributions.categorical.Categorical( - logits=logits - ) - input = ( - 1 - mask_generate - ) * input + mask_generate * dist.sample() - + targets = input.clone() + degrade_input_inplace(input, mask_generate, pure_noise=pure_noise) output = model( mygpt.BracketedSequence( torch.cat([input[:, :, None], mask_generate[:, :, None]], dim=2) ) ).x - loss = F.cross_entropy(output.transpose(1, 2), targets) + loss_per_token = F.cross_entropy( + output.transpose(1, 2), targets, reduction="none" + ) + loss = (loss_per_token * mask_loss).mean() acc_test_loss += loss.item() * input.size(0) nb_test_samples += input.size(0) @@ -1029,46 +1000,8 @@ def test_ae(local_device=main_device): ae_batches(quiz_machine, 128, [s], local_device) ) - targets = input - - input = (1 - mask_generate) * input # PARANOIAAAAAAAAAA - - if pure_noise: - mask_diffusion_noise = torch.rand( - mask_generate.size(), device=mask_generate.device - ) <= torch.rand( - (mask_generate.size(0), 1), device=mask_generate.device - ) - - mask_diffusion_noise = mask_diffusion_noise.long() - - input = ( - input - + mask_generate - * mask_diffusion_noise - * torch.randint( - quiz_machine.problem.nb_colors, - input.size(), - device=input.device, - ) - ) - else: - for it in range(torch.randint(5, (1,)).item()): - logits = model( - mygpt.BracketedSequence( - torch.cat( - [input[:, :, None], mask_generate[:, :, None]], - dim=2, - ) - ) - ).x - dist = torch.distributions.categorical.Categorical( - logits=logits - ) - input = ( - 1 - mask_generate - ) * input + mask_generate * dist.sample() - + targets = input.clone() + degrade_input_inplace(input, mask_generate, pure_noise=pure_noise) result = input not_converged = torch.full( @@ -1082,7 +1015,7 @@ def test_ae(local_device=main_device): torch.cat( [ result[not_converged, :, None], - mask_generate[:, :, None], + mask_generate[not_converged, :, None], ], dim=2, ) -- 2.39.5