From 218d97bf16bc705279570ae9085c5752fee71e91 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 23 Aug 2024 08:27:28 +0200 Subject: [PATCH] Update. --- main.py | 101 +++++++++++++++++++++++++++----------------------------- 1 file changed, 49 insertions(+), 52 deletions(-) diff --git a/main.py b/main.py index cd78959..e7dd337 100755 --- a/main.py +++ b/main.py @@ -750,7 +750,7 @@ from mygpt import ( ) -class MyAttentionVAE(nn.Module): +class MyAttentionAE(nn.Module): def __init__( self, vocabulary_size, @@ -849,7 +849,7 @@ def ae_batches(quiz_machine, nb, data_structures, local_device, desc=None): def test_ae(local_device=main_device): - model = MyAttentionVAE( + model = MyAttentionAE( vocabulary_size=vocabulary_size, dim_model=args.dim_model, dim_keys=args.dim_keys, @@ -962,74 +962,71 @@ def test_ae(local_device=main_device): # ------------------------------------------- # Test generation - input, mask_generate, mask_loss = next( - ae_batches(quiz_machine, 128, data_structures, local_device) - ) + for ns, s in enumerate(data_structures): + quad_order, quad_generate, _, _ = s - targets = input + input, mask_generate, mask_loss = next( + ae_batches(quiz_machine, 128, [s], local_device) + ) - input = (1 - mask_generate) * input # PARANOIAAAAAAAAAA + targets = input - result = (1 - mask_generate) * input + mask_generate * torch.randint( - quiz_machine.problem.nb_colors, input.size(), device=input.device - ) + input = (1 - mask_generate) * input # PARANOIAAAAAAAAAA - not_converged = torch.full((result.size(0),), True, device=result.device) + result = (1 - mask_generate) * input + mask_generate * torch.randint( + quiz_machine.problem.nb_colors, input.size(), device=input.device + ) - nb_it = 0 + not_converged = torch.full( + (result.size(0),), True, device=result.device + ) - while True: - logits = model(mygpt.BracketedSequence(result)).x - dist = torch.distributions.categorical.Categorical(logits=logits) - pred_result = result.clone() - update = (1 - mask_generate) * input + mask_generate * dist.sample() - result[not_converged] = update[not_converged] - not_converged = (pred_result != result).max(dim=1).values - nb_it += 1 - print("DEBUG", nb_it, not_converged.long().sum().item()) - if not not_converged.any() or nb_it > 100: - break + nb_it = 0 - correct = (result == targets).min(dim=1).values.long() - predicted_parts = input.new(input.size(0), 4) + while True: + logits = model(mygpt.BracketedSequence(result)).x + dist = torch.distributions.categorical.Categorical(logits=logits) + pred_result = result.clone() + update = (1 - mask_generate) * input + mask_generate * dist.sample() + result[not_converged] = update[not_converged] + not_converged = (pred_result != result).max(dim=1).values + nb_it += 1 + print("DEBUG", nb_it, not_converged.long().sum().item()) + if not not_converged.any() or nb_it > 100: + break - nb = 0 + correct = (result == targets).min(dim=1).values.long() + predicted_parts = input.new(input.size(0), 4) - # We consider all the configurations that we train for - for quad_order, quad_generate, _, _ in quiz_machine.test_structures: - i = quiz_machine.problem.indices_select( - quizzes=input, quad_order=quad_order - ) - nb += i.long().sum() + nb = 0 - predicted_parts[i] = torch.tensor(quad_generate, device=result.device)[ + predicted_parts = torch.tensor(quad_generate, device=result.device)[ None, : ] - solution_is_deterministic = predicted_parts[i].sum(dim=-1) == 1 - correct[i] = (2 * correct[i] - 1) * (solution_is_deterministic).long() - - assert nb == input.size(0) + solution_is_deterministic = predicted_parts.sum(dim=-1) == 1 + correct = (2 * correct - 1) * (solution_is_deterministic).long() - nb_correct = (correct == 1).long().sum() - nb_total = (correct != 0).long().sum() + nb_correct = (correct == 1).long().sum() + nb_total = (correct != 0).long().sum() - log_string( - f"test_accuracy {n_epoch} model AE {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)" - ) + log_string( + f"test_accuracy {n_epoch} model AE setup {ns} {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)" + ) - correct_parts = predicted_parts * correct[:, None] + correct_parts = predicted_parts * correct[:, None] + predicted_parts = predicted_parts.expand_as(correct_parts) - filename = f"prediction_ae_{n_epoch:04d}.png" + filename = f"prediction_ae_{n_epoch:04d}_{ns}.png" - quiz_machine.problem.save_quizzes_as_image( - args.result_dir, - filename, - quizzes=result, - predicted_parts=predicted_parts, - correct_parts=correct_parts, - ) + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, + filename, + quizzes=result, + predicted_parts=predicted_parts, + correct_parts=correct_parts, + ) - log_string(f"wrote {filename}") + log_string(f"wrote {filename}") if args.test == "ae": -- 2.39.5