Update: rename MyAttentionVAE to MyAttentionAE and evaluate test_ae generation separately for each data structure.
author     François Fleuret <francois@fleuret.org>
           Fri, 23 Aug 2024 06:27:28 +0000 (08:27 +0200)
committer  François Fleuret <francois@fleuret.org>
           Fri, 23 Aug 2024 06:27:28 +0000 (08:27 +0200)
main.py

diff --git a/main.py b/main.py
index cd78959..e7dd337 100755
--- a/main.py
+++ b/main.py
@@ -750,7 +750,7 @@ from mygpt import (
 )
 
 
-class MyAttentionVAE(nn.Module):
+class MyAttentionAE(nn.Module):
     def __init__(
         self,
         vocabulary_size,
@@ -849,7 +849,7 @@ def ae_batches(quiz_machine, nb, data_structures, local_device, desc=None):
 
 
 def test_ae(local_device=main_device):
-    model = MyAttentionVAE(
+    model = MyAttentionAE(
         vocabulary_size=vocabulary_size,
         dim_model=args.dim_model,
         dim_keys=args.dim_keys,
@@ -962,74 +962,71 @@ def test_ae(local_device=main_device):
             # -------------------------------------------
             # Test generation
 
-            input, mask_generate, mask_loss = next(
-                ae_batches(quiz_machine, 128, data_structures, local_device)
-            )
+            for ns, s in enumerate(data_structures):
+                quad_order, quad_generate, _, _ = s
 
-            targets = input
+                input, mask_generate, mask_loss = next(
+                    ae_batches(quiz_machine, 128, [s], local_device)
+                )
 
-            input = (1 - mask_generate) * input  # PARANOIAAAAAAAAAA
+                targets = input
 
-            result = (1 - mask_generate) * input + mask_generate * torch.randint(
-                quiz_machine.problem.nb_colors, input.size(), device=input.device
-            )
+                input = (1 - mask_generate) * input  # PARANOIAAAAAAAAAA
 
-            not_converged = torch.full((result.size(0),), True, device=result.device)
+                result = (1 - mask_generate) * input + mask_generate * torch.randint(
+                    quiz_machine.problem.nb_colors, input.size(), device=input.device
+                )
 
-            nb_it = 0
+                not_converged = torch.full(
+                    (result.size(0),), True, device=result.device
+                )
 
-            while True:
-                logits = model(mygpt.BracketedSequence(result)).x
-                dist = torch.distributions.categorical.Categorical(logits=logits)
-                pred_result = result.clone()
-                update = (1 - mask_generate) * input + mask_generate * dist.sample()
-                result[not_converged] = update[not_converged]
-                not_converged = (pred_result != result).max(dim=1).values
-                nb_it += 1
-                print("DEBUG", nb_it, not_converged.long().sum().item())
-                if not not_converged.any() or nb_it > 100:
-                    break
+                nb_it = 0
 
-            correct = (result == targets).min(dim=1).values.long()
-            predicted_parts = input.new(input.size(0), 4)
+                while True:
+                    logits = model(mygpt.BracketedSequence(result)).x
+                    dist = torch.distributions.categorical.Categorical(logits=logits)
+                    pred_result = result.clone()
+                    update = (1 - mask_generate) * input + mask_generate * dist.sample()
+                    result[not_converged] = update[not_converged]
+                    not_converged = (pred_result != result).max(dim=1).values
+                    nb_it += 1
+                    print("DEBUG", nb_it, not_converged.long().sum().item())
+                    if not not_converged.any() or nb_it > 100:
+                        break
 
-            nb = 0
+                correct = (result == targets).min(dim=1).values.long()
+                predicted_parts = input.new(input.size(0), 4)
 
-            # We consider all the configurations that we train for
-            for quad_order, quad_generate, _, _ in quiz_machine.test_structures:
-                i = quiz_machine.problem.indices_select(
-                    quizzes=input, quad_order=quad_order
-                )
-                nb += i.long().sum()
+                nb = 0
 
-                predicted_parts[i] = torch.tensor(quad_generate, device=result.device)[
+                predicted_parts = torch.tensor(quad_generate, device=result.device)[
                     None, :
                 ]
-                solution_is_deterministic = predicted_parts[i].sum(dim=-1) == 1
-                correct[i] = (2 * correct[i] - 1) * (solution_is_deterministic).long()
-
-            assert nb == input.size(0)
+                solution_is_deterministic = predicted_parts.sum(dim=-1) == 1
+                correct = (2 * correct - 1) * (solution_is_deterministic).long()
 
-            nb_correct = (correct == 1).long().sum()
-            nb_total = (correct != 0).long().sum()
+                nb_correct = (correct == 1).long().sum()
+                nb_total = (correct != 0).long().sum()
 
-            log_string(
-                f"test_accuracy {n_epoch} model AE {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)"
-            )
+                log_string(
+                    f"test_accuracy {n_epoch} model AE setup {ns} {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)"
+                )
 
-            correct_parts = predicted_parts * correct[:, None]
+                correct_parts = predicted_parts * correct[:, None]
+                predicted_parts = predicted_parts.expand_as(correct_parts)
 
-            filename = f"prediction_ae_{n_epoch:04d}.png"
+                filename = f"prediction_ae_{n_epoch:04d}_{ns}.png"
 
-            quiz_machine.problem.save_quizzes_as_image(
-                args.result_dir,
-                filename,
-                quizzes=result,
-                predicted_parts=predicted_parts,
-                correct_parts=correct_parts,
-            )
+                quiz_machine.problem.save_quizzes_as_image(
+                    args.result_dir,
+                    filename,
+                    quizzes=result,
+                    predicted_parts=predicted_parts,
+                    correct_parts=correct_parts,
+                )
 
-            log_string(f"wrote {filename}")
+                log_string(f"wrote {filename}")
 
 
 if args.test == "ae":
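
The last hunk replaces the single generation test with a loop over the data structures: for each structure, the positions flagged by mask_generate are initialized with random symbols and then repeatedly re-sampled from the model's output distribution until no sequence changes (or 100 iterations are reached), after which accuracy is logged and an image is written per setup. Below is a minimal standalone sketch of that fixed-point resampling loop, assuming a toy stand-in model instead of MyAttentionAE and omitting the mygpt.BracketedSequence wrapper and the per-structure bookkeeping; the name generate_by_resampling and the toy model are illustrative, not code from the repository.

# A minimal sketch of the masked-resampling generation loop, assuming a toy
# stand-in model; generate_by_resampling and the model below are illustrative,
# not code from the repository.

import torch
import torch.nn as nn


def generate_by_resampling(model, input, mask_generate, nb_symbols, max_it=100):
    # Positions with mask_generate == 1 start from uniform random symbols,
    # the remaining positions are clamped to the input.
    result = (1 - mask_generate) * input + mask_generate * torch.randint(
        nb_symbols, input.size(), device=input.device
    )

    # One flag per sequence: keep re-sampling a sequence only while it
    # changed during the previous iteration.
    not_converged = torch.full((result.size(0),), True, device=result.device)

    for _ in range(max_it):
        logits = model(result)  # (N, T, nb_symbols)
        dist = torch.distributions.Categorical(logits=logits)
        previous = result.clone()
        # Re-sample the masked positions, keep the clamped ones untouched.
        update = (1 - mask_generate) * input + mask_generate * dist.sample()
        result[not_converged] = update[not_converged]
        not_converged = (previous != result).any(dim=1)
        if not not_converged.any():
            break

    return result


if __name__ == "__main__":
    N, T, nb_symbols = 8, 16, 5

    # Toy stand-in for the auto-encoder, just to make the sketch executable:
    # an embedding followed by a per-position linear read-out.
    model = nn.Sequential(nn.Embedding(nb_symbols, 32), nn.Linear(32, nb_symbols))

    input = torch.randint(nb_symbols, (N, T))
    mask_generate = (torch.rand(N, T) < 0.25).long()

    result = generate_by_resampling(model, input, mask_generate, nb_symbols)
    # The clamped positions must come back unchanged.
    print((result == input)[mask_generate == 0].all().item())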