Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 15 Sep 2024 10:40:21 +0000 (12:40 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 15 Sep 2024 10:40:21 +0000 (12:40 +0200)
main.py

diff --git a/main.py b/main.py
index 19a3fee..49799e4 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -59,7 +59,7 @@ parser.add_argument("--physical_batch_size", type=int, default=None)
 
 parser.add_argument("--inference_batch_size", type=int, default=25)
 
-parser.add_argument("--nb_train_samples", type=int, default=25000)
+parser.add_argument("--nb_train_samples", type=int, default=100000)
 
 parser.add_argument("--nb_test_samples", type=int, default=1000)
 
@@ -567,7 +567,7 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
 
     nb_train_samples, acc_train_loss = 0, 0.0
 
-    scaler = torch.amp.GradScaler("cuda")
+    scaler = torch.amp.GradScaler("cuda")
 
     for x_0, mask_generate in ae_batches(
         quiz_machine,
@@ -595,17 +595,17 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
         acc_train_loss += loss.item() * x_0.size(0)
         nb_train_samples += x_0.size(0)
 
-        loss.backward()
+        loss.backward()
 
-        if nb_train_samples % args.batch_size == 0:
-        # model.optimizer.step()
+        if nb_train_samples % args.batch_size == 0:
+            model.optimizer.step()
 
-        scaler.scale(loss).backward()
+        scaler.scale(loss).backward()
 
-        if nb_train_samples % args.batch_size == 0:
-            scaler.step(model.optimizer)
+        if nb_train_samples % args.batch_size == 0:
+        # scaler.step(model.optimizer)
 
-        scaler.update()
+        scaler.update()
 
     log_string(
         f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}"