Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 14 Sep 2024 19:15:37 +0000 (21:15 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 14 Sep 2024 19:15:37 +0000 (21:15 +0200)
attae.py
main.py

index 9a2f240..06deed2 100755 (executable)
--- a/attae.py
+++ b/attae.py
@@ -51,8 +51,6 @@ def attention(q, k, v):
     return y
 
 
-attention = torch.compile(attention)
-
 ######################################################################
 
 
diff --git a/main.py b/main.py
index 8010fa4..62cbd2f 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -121,8 +121,6 @@ parser.add_argument("--nb_hints", type=int, default=25)
 
 parser.add_argument("--nb_runs", type=int, default=1)
 
-parser.add_argument("--dirty_debug", action="store_true", default=False)
-
 parser.add_argument("--test", type=str, default=None)
 
 parser.add_argument("--quizzes", type=str, default=None)
@@ -210,7 +208,7 @@ else:
 
 if args.resume:
     if not os.path.isdir(args.result_dir):
-        print(f"Trying to resume with a non-existing result dir {args.result_dir}.")
+        print(f"Trying to resume from a non-existing result dir {args.result_dir}.")
         exit(1)
 else:
     try:
@@ -276,10 +274,6 @@ else:
     assert len(gpus) == 0
     main_device = torch.device("cpu")
 
-if args.dirty_debug:
-    args.nb_train_samples = 2500
-    args.nb_test_samples = 100
-
 if args.physical_batch_size is None:
     args.physical_batch_size = args.batch_size
 else:
@@ -720,7 +714,7 @@ def logits_hat_x_0_from_random_iteration(model, x_0, mask_generate, prompt_noise
 
     x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
 
-    with torch.cuda.amp.autocast():
+    with torch.amp.autocast("cuda"):
         logits_hat_x_0 = model(x_t_with_mask)
 
     return logits_hat_x_0
@@ -745,7 +739,7 @@ def ae_generate(model, x_0, mask_generate, nb_iterations_max=50, mask_hints=None
 
     for it in range(nb_iterations_max):
         x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
-        with torch.cuda.amp.autocast():
+        with torch.amp.autocast("cuda"):
             logits = model(x_t_with_mask)
         logits[:, :, quiz_machine.problem.nb_colors :] = float("-inf")
         dist = torch.distributions.categorical.Categorical(logits=logits)
@@ -894,7 +888,7 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
 
     nb_train_samples, acc_train_loss = 0, 0.0
 
-    scaler = torch.cuda.amp.GradScaler()
+    scaler = torch.amp.GradScaler("cuda")
 
     for x_0, mask_generate in ae_batches(
         quiz_machine,
@@ -910,7 +904,7 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
         if nb_train_samples % args.batch_size == 0:
             model.optimizer.zero_grad()
 
-        with torch.cuda.amp.autocast():
+        with torch.amp.autocast("cuda"):
             logits = logits_hat_x_0_from_random_iteration(
                 model, x_0, mask_generate, prompt_noise=args.prompt_noise
             )
@@ -963,6 +957,8 @@ for i in range(args.nb_models):
         dropout=args.dropout,
     ).to(main_device)
 
+    model = torch.compile(model)
+
     model.id = i
     model.test_accuracy = 0.0
     model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
@@ -1333,6 +1329,8 @@ def save_models(models, suffix=""):
 ######################################################################
 
 for n_epoch in range(current_epoch, args.nb_epochs):
+    start_time = time.perf_counter()
+
     state = {
         "current_epoch": n_epoch,
         "c_quizzes": c_quizzes,
@@ -1349,12 +1347,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
     # --------------------------------------------------------------------
 
-    log_string(f"{time_train=} {time_c_quizzes=}")
-
-    if (
-        min([float(m.test_accuracy) for m in models]) > args.accuracy_to_make_c_quizzes
-        and time_train >= time_c_quizzes
-    ):
+    if min([float(m.test_accuracy) for m in models]) > args.accuracy_to_make_c_quizzes:
         if c_quizzes is None:
             save_models(models, "naive")
 
@@ -1362,11 +1355,16 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         nb_gpus = len(gpus)
         nb_c_quizzes_to_generate = (args.nb_c_quizzes + nb_gpus - 1) // nb_gpus
 
-        start_time = time.perf_counter()
+        args = [(models, nb_c_quizzes_to_generate, gpu) for gpu in gpus]
+
+        # Ugly hack: Only one thread during the first epoch so that
+        # compilation of the model does not explode
+        if n_epoch == 0:
+            args = args[:1]
 
         c_quizzes, agreements = multithread_execution(
             generate_ae_c_quizzes,
-            [(models, nb_c_quizzes_to_generate, gpu) for gpu in gpus],
+            args,
         )
 
         save_c_quizzes_with_scores(
@@ -1385,8 +1383,6 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
         log_string(f"generated_c_quizzes {c_quizzes.size()=}")
 
-        time_train = 0
-
         for model in models:
             model.test_accuracy = 0
 
@@ -1400,8 +1396,6 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
     weakest_models = ranked_models[: len(gpus)]
 
-    start_time = time.perf_counter()
-
     # None if c_quizzes is None else c_quizzes[agreements[:, model.id]],
 
     multithread_execution(
@@ -1412,8 +1406,6 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         ],
     )
 
-    time_train += int(time.perf_counter() - start_time)
-
     # --------------------------------------------------------------------
 
     save_models(models)