Update.
author     François Fleuret <francois@fleuret.org>
Thu, 22 Aug 2024 20:23:32 +0000 (22:23 +0200)
committer  François Fleuret <francois@fleuret.org>
Thu, 22 Aug 2024 20:23:32 +0000 (22:23 +0200)
main.py

diff --git a/main.py b/main.py
index b6c62cf..f28c10a 100755
--- a/main.py
+++ b/main.py
@@ -871,7 +871,12 @@ def test_ae(local_device=main_device):
                 model.optimizer.zero_grad()
 
             targets = input
-            input = (mask_generate == 0).long() * input
+
+            input = (mask_generate == 0).long() * input + (
+                1 - (mask_generate == 0).long()
+            ) * torch.randint(
+                quiz_machine.problem.nb_colors, input.size(), device=input.device
+            )
 
             output = model(mygpt.BracketedSequence(input)).x
             loss = F.cross_entropy(output.transpose(1, 2), targets)
@@ -915,7 +920,13 @@ def test_ae(local_device=main_device):
                 mask_loss = mask_loss.to(local_device)
 
                 targets = input
-                input = (mask_generate == 0).long() * input
+
+                input = (mask_generate == 0).long() * input + (
+                    1 - (mask_generate == 0).long()
+                ) * torch.randint(
+                    quiz_machine.problem.nb_colors, input.size(), device=input.device
+                )
+
                 output = model(mygpt.BracketedSequence(input)).x
                 loss = F.cross_entropy(output.transpose(1, 2), targets)
                 acc_test_loss += loss.item() * input.size(0)
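
Both hunks above make the same change, once in the training pass and once in the test-loss pass: instead of zeroing the positions the model has to generate, those positions are overwritten with uniform random color tokens, so the corruption is noise rather than a constant blank. A minimal standalone sketch of that operation (the helper name corrupt_masked is illustrative, not from the commit):

    import torch

    def corrupt_masked(input, mask_generate, nb_colors):
        # keep == 1 where the token is given (mask_generate == 0),
        # keep == 0 where the model has to generate it
        keep = (mask_generate == 0).long()
        noise = torch.randint(nb_colors, input.size(), device=input.device)
        # original token where kept, uniform random color elsewhere
        return keep * input + (1 - keep) * noise
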
@@ -928,15 +939,38 @@ def test_ae(local_device=main_device):
             mask_generate = mask_generate.to(local_device)
             mask_loss = mask_loss.to(local_device)
             targets = input
-            input = (mask_generate == 0).long() * input
-            logits = model(mygpt.BracketedSequence(input)).x
-            dist = torch.distributions.categorical.Categorical(logits=logits)
-            result = dist.sample()
+
+            pred_result = None
+            frozen = None
+
+            result = (mask_generate == 0).long() * input + (
+                1 - (mask_generate == 0).long()
+            ) * torch.randint(
+                quiz_machine.problem.nb_colors, input.size(), device=input.device
+            )
+
+            i = torch.full((result.size(0),), True, device=result.device)
+
+            nb_it = 0
+
             L = input.size(1) // 4
-            result[:, 0 * L] = input[:, 0 * L]
-            result[:, 1 * L] = input[:, 1 * L]
-            result[:, 2 * L] = input[:, 2 * L]
-            result[:, 3 * L] = input[:, 3 * L]
+
+            while True:
+                logits = model(mygpt.BracketedSequence(result)).x
+                dist = torch.distributions.categorical.Categorical(logits=logits)
+                pred_result = result.clone()
+                result[i] = dist.sample()[i]
+                result[:, 0 * L] = input[:, 0 * L]
+                result[:, 1 * L] = input[:, 1 * L]
+                result[:, 2 * L] = input[:, 2 * L]
+                result[:, 3 * L] = input[:, 3 * L]
+                changed = (pred_result == result).long().min(dim=1).values == 0
+                i = i & changed
+                nb_it += 1
+                print("DEBUG", nb_it, i.long().sum().item())
+                if not i.any() or nb_it > 100:
+                    break
+
             correct = (result == targets).min(dim=1).values.long()
             predicted_parts = input.new(input.size(0), 4)
 
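
The third hunk replaces the single sampling pass with an iterative refinement loop: generation starts from the same random corruption, every still-active sequence is resampled in full, the four tokens at offsets 0*L, 1*L, 2*L and 3*L are pinned back to the input, and a sequence drops out of the active set once a full pass leaves it unchanged; the loop stops when no sequence is active or after 100 iterations (the DEBUG print tracks that count). A condensed sketch of the same fixed-point iteration, assuming a model that maps an (N, T) token tensor directly to (N, T, vocab) logits (the real code wraps the input in mygpt.BracketedSequence; sample_until_stable and its signature are illustrative, not from the commit):

    import torch

    def sample_until_stable(model, input, mask_generate, nb_colors, max_it=100):
        keep = (mask_generate == 0).long()
        # start from the same uniform-noise corruption as training
        result = keep * input + (1 - keep) * torch.randint(
            nb_colors, input.size(), device=input.device
        )
        active = torch.full((result.size(0),), True, device=result.device)
        L = input.size(1) // 4
        for _ in range(max_it):
            logits = model(result)  # (N, T, vocab)
            dist = torch.distributions.Categorical(logits=logits)
            previous = result.clone()
            # resample whole sequences that were still changing
            result[active] = dist.sample()[active]
            # pin the first token of each quarter to the ground truth
            for k in range(4):
                result[:, k * L] = input[:, k * L]
            # a sequence stays active only while a pass still changes it
            active = active & (previous != result).any(dim=1)
            if not active.any():
                break
        return result

Sampling every position and then overwriting the pinned offsets keeps the indexing simple; the few samples drawn at the pinned positions are simply discarded.
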
@@ -958,8 +992,8 @@ def test_ae(local_device=main_device):
             nb_correct = (correct == 1).long().sum()
             nb_total = (correct != 0).long().sum()
 
-            self.logger(
-                f"test_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)"
+            log_string(
+                f"test_accuracy {n_epoch} model AE {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)"
             )
 
             correct_parts = predicted_parts * correct[:, None]
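
The last hunk only changes how the result is reported (log_string instead of self.logger, and a fixed AE tag instead of model.id); the accuracy itself is the exact-match criterion computed just above: a prediction counts as correct only if every token of the sequence matches its target. A short sketch of that check, assuming result and targets are (N, T) integer tensors:

    # correct[n] is 1 iff sequence n matches its target at every position
    correct = (result == targets).min(dim=1).values.long()
    # simplified denominator; the commit filters with (correct != 0)
    accuracy = 100.0 * (correct == 1).sum().item() / correct.numel()
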