Update.
author François Fleuret <francois@fleuret.org>
Fri, 12 Jul 2024 15:36:30 +0000 (17:36 +0200)
committer François Fleuret <francois@fleuret.org>
Fri, 12 Jul 2024 15:36:30 +0000 (17:36 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index 8715711..a8ceac8 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -32,6 +32,8 @@ parser.add_argument("--result_dir", type=str, default=None)
 
 parser.add_argument("--seed", type=int, default=0)
 
+parser.add_argument("--resume", action="store_true", default=False)
+
 parser.add_argument("--max_percents_of_test_in_train", type=int, default=-1)
 
 ########################################
@@ -190,11 +192,15 @@ else:
 
 ######################################################################
 
-try:
-    os.mkdir(args.result_dir)
-except FileExistsError:
-    print(f"result directory {args.result_dir} already exists")
-    exit(1)
+if args.resume:
+    assert os.path.isdir(args.result_dir)
+
+else:
+    try:
+        os.mkdir(args.result_dir)
+    except FileExistsError:
+        print(f"result directory {args.result_dir} already exists")
+        exit(1)
 
 log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
 
@@ -437,8 +443,7 @@ def create_c_quizzes(
     quiz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True)
     quiz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False)
 
-    # save a bunch of images to investigate what quizzes with a
-    # certain nb of correct predictions look like
+    # save images
 
     q = new_c_quizzes[:72]
 
@@ -450,8 +455,6 @@ def create_c_quizzes(
 
 ######################################################################
 
-nb_loaded_models = 0
-
 models = []
 
 for k in range(args.nb_gpts):
@@ -475,23 +478,37 @@ for k in range(args.nb_gpts):
     model.test_w_quizzes = quiz_machine.generate_token_sequences(args.nb_test_samples)
     quiz_machine.reverse_random_half_in_place(model.test_w_quizzes)
 
-    filename = f"gpt_{model.id:03d}.pth"
+    models.append(model)
+
+######################################################################
 
+if args.resume:
     try:
-        model.load_state_dict(torch.load(os.path.join(args.result_dir, filename)))
-        log_string(f"model {model.id} successfully loaded from checkpoint.")
-        nb_loaded_models += 1
-
-    except FileNotFoundError:
-        log_string(f"starting model {model.id} from scratch.")
+        for model in models:
+            filename = f"gpt_{model.id:03d}.pth"
+
+            try:
+                model.load_state_dict(
+                    torch.load(os.path.join(args.result_dir, filename))
+                )
+                log_string(f"successfully loaded {filename}")
+            except FileNotFoundError:
+                log_string(f"cannot find {filename}")
+                pass
+
+        try:
+            filename = "c_quizzes.pth"
+            quiz_machine.load_c_quizzes(os.path.join(args.result_dir, filename))
+            log_string(f"successfully loaded {filename}")
+        except FileNotFoundError:
+            log_string(f"cannot find {filename}")
+            pass
 
     except:
         log_string(f"error when loading {filename}.")
         exit(1)
 
-    models.append(model)
-
-assert nb_loaded_models == 0 or nb_loaded_models == len(models)
+######################################################################
 
 nb_parameters = sum(p.numel() for p in models[0].parameters())
 log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
@@ -600,6 +617,7 @@ for n_epoch in range(args.nb_epochs):
     for model in weakest_models:
         filename = f"gpt_{model.id:03d}.pth"
         torch.save(model.state_dict(), os.path.join(args.result_dir, filename))
+        log_string(f"wrote {filename}")
 
     ##################################################
     # Replace a fraction of the w_quizzes with fresh ones
@@ -625,4 +643,6 @@ for n_epoch in range(args.nb_epochs):
             nb_for_test=nb_new_c_quizzes_for_test,
         )
 
+        quiz_machine.save_c_quizzes(os.path.join(args.result_dir, "c_quizzes.pth"))
+
 ######################################################################
diff --git a/quiz_machine.py b/quiz_machine.py
index 88fd9f1..c39bf7a 100755 (executable)
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -412,6 +412,12 @@ class QuizMachine:
             else:
                 self.test_c_quizzes.append(new_c_quizzes.to("cpu"))
 
+    def save_c_quizzes(self, filename):
+        torch.save((self.train_c_quizzes, self.test_c_quizzes), filename)
+
+    def load_c_quizzes(self, filename):
+        self.train_c_quizzes, self.test_c_quizzes = torch.load(filename)
+
     ######################################################################
 
     def logproba_of_solutions(self, models, c_quizzes):