Update.
author	François Fleuret <francois@fleuret.org>
Sat, 17 Aug 2024 20:49:53 +0000 (22:49 +0200)
committer	François Fleuret <francois@fleuret.org>
Sat, 17 Aug 2024 20:49:53 +0000 (22:49 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index 92bc05f..127b71b 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -115,6 +115,8 @@ parser.add_argument("--dirty_debug", action="store_true", default=False)
 
 parser.add_argument("--test", type=str, default=None)
 
+parser.add_argument("--logit_std_max", type=float, default=-1)
+
 ######################################################################
 
 grids_tasks = ", ".join(
@@ -820,6 +822,21 @@ for k in range(args.nb_gpts):
         dropout=args.dropout,
     ).to(main_device)
 
+    class UpperBoundStd(nn.Module):
+        def __init__(self, std_max=1.0):
+            super().__init__()
+            self.std_max = std_max
+
+        def forward(self, x):
+            std = x.std(dim=-1, keepdim=True)
+            y = (x - x.mean(dim=-1, keepdim=True)) / std.clamp(max=self.std_max)
+            return y
+
+    if args.logit_std_max > 0:
+        model.readout.f = nn.Sequential(
+            model.readout.f, UpperBoundStd(std_max=args.logit_std_max)
+        )
+
     model.id = k
     model.train_c_quiz_bags = []
     model.test_c_quiz_bags = []
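With the new --logit_std_max flag (disabled at its default of -1), the readout is wrapped in an nn.Sequential with UpperBoundStd, which re-centers the logits and divides them by their standard deviation clamped to at most std_max. A standalone sketch of the effect on a toy tensor, not part of the commit (the std_max value and tensor sizes are illustrative):

import torch
from torch import nn


class UpperBoundStd(nn.Module):
    # Same layer as in the hunk above: center the last dimension and divide
    # by the standard deviation clamped to at most std_max.
    def __init__(self, std_max=1.0):
        super().__init__()
        self.std_max = std_max

    def forward(self, x):
        std = x.std(dim=-1, keepdim=True)
        return (x - x.mean(dim=-1, keepdim=True)) / std.clamp(max=self.std_max)


logits = torch.randn(4, 1000) * 5.0            # per-row std of roughly 5
bounded = UpperBoundStd(std_max=2.0)(logits)
print(logits.std(dim=-1))                      # ~5
print(bounded.std(dim=-1))                     # ~2.5, i.e. the original std divided by std_max

Logits whose standard deviation is below std_max come out fully standardized (std of 1), while higher-variance logits keep a standard deviation of std / std_max.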
@@ -1034,36 +1051,57 @@ def save_generated_c_quizzes(model, filename, nb=64):
 
 ######################################################################
 
+
 if args.test == "entropy":
     model = models[0]
     model.to(main_device)
 
-    log_string("starting testing entropy maximization")
-
-    train_input = quiz_machine.generate_c_quizzes(
-        1000,
-        model_for_generation=model,
-        procedure=c_quizzes_procedure,
-    )
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
 
-    for n_epoch in range(10):
-        nb_train_samples, acc_train_loss = 0, 0.0
+    log_string("starting testing entropy maximization")
 
-        for input in train_input.split(args.batch_size):
-            input = input.to(main_device)
-            output = model(mygpt.BracketedSequence(input)).x
-            loss = output.log_softmax(dim=1).mean()
+    for n_epoch in range(100):
+        input = quiz_machine.generate_c_quizzes(
+            128,
+            model_for_generation=model,
+            procedure=c_quizzes_procedure,
+        )
 
-            acc_train_loss += loss.item() * input.size(0)
-            nb_train_samples += input.size(0)
+        quiz_machine.problem.save_quizzes_as_image(
+            args.result_dir,
+            f"test_{n_epoch:04d}.png",
+            quizzes=input,
+        )
 
-            model.optimizer.zero_grad()
-            loss.backward()
-            model.optimizer.step()
+        log_string(f"wrote {filename}")
 
-        log_string(
-            f"increase_entropy {n_epoch} entropy {acc_train_loss/nb_train_samples}"
-        )
+        with torch.no_grad():
+            for p in model.parameters():
+                p += torch.randn(p.size(), device=p.device) * 1e-3
+
+        # nb_train_samples, acc_train_loss = 0, 0.0
+
+        # for k in range(1000 // args.batch_size):
+        # input = quiz_machine.generate_c_quizzes(
+        # args.batch_size,
+        # model_for_generation=model,
+        # procedure=[(("f_B", "f_A", "A", "B"), (1, 1, 1, 1), None)],
+        # )
+
+        # input = input.to(main_device)
+        # targets = input
+        # output = model(mygpt.BracketedSequence(input)).x
+        # loss = -F.cross_entropy(output.transpose(1, 2), targets)
+        # acc_train_loss += loss.item() * input.size(0)
+        # nb_train_samples += input.size(0)
+
+        # optimizer.zero_grad()
+        # loss.backward()
+        # optimizer.step()
+
+        # log_string(
+        # f"increase_entropy {n_epoch} entropy {acc_train_loss/nb_train_samples}"
+        # )
 
     exit(0)
 
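In this version the --test entropy path no longer takes gradient steps: each of the 100 iterations generates 128 c-quizzes, writes them to test_{n_epoch:04d}.png, and then perturbs every parameter with Gaussian noise of standard deviation 1e-3 (the Adam optimizer created above is left unused). The gradient-based entropy-maximization step survives only in the comments; a standalone sketch of that step on a toy classifier (the model, sizes and data below are illustrative, and the real code applies the cross-entropy over token sequences, hence the transpose(1, 2) in the comments):

import torch
from torch import nn
import torch.nn.functional as F

model = nn.Linear(16, 32)                      # toy stand-in for the GPT
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

x = torch.randn(64, 16)                        # stand-in for a batch of generated quizzes
targets = torch.randint(0, 32, (64,))          # the model's own samples serve as targets

output = model(x)
# Minimizing the negative cross-entropy pushes probability mass away from the
# sampled targets, which is how the commented-out code increases entropy.
loss = -F.cross_entropy(output, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()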
diff --git a/quiz_machine.py b/quiz_machine.py
index 98e0ea5..18136e8 100755 (executable)
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -355,7 +355,7 @@ class QuizMachine:
                 input=c_quizzes,
                 ar_mask=self.make_quiz_mask(c_quizzes, s, m),
                 seq_logprobas=seq_logprobas,
-                progress_bar_desc=f"autoregression {n_step}/{len(procedure)}",
+                progress_bar_desc=f"autoregression {n_step+1}/{len(procedure)}",
             )
 
             model_for_generation.reset_transformations()
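The quiz_machine.py change only affects the progress-bar label: assuming n_step comes from enumerating the procedure steps, the displayed count is now 1-based. A trivial illustration (the step names are placeholders):

procedure = ["step_a", "step_b", "step_c"]     # hypothetical placeholder steps
for n_step, step in enumerate(procedure):
    # Prints 1/3, 2/3, 3/3 instead of the previous 0/3, 1/3, 2/3.
    print(f"autoregression {n_step+1}/{len(procedure)}")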