Update.
author François Fleuret <francois@fleuret.org>
Sun, 15 Sep 2024 10:18:18 +0000 (12:18 +0200)
committer François Fleuret <francois@fleuret.org>
Sun, 15 Sep 2024 10:18:18 +0000 (12:18 +0200)
diffusion.py
main.py

diff --git a/diffusion.py b/diffusion.py
index 98d8d0a..8c6e08d 100755
--- a/diffusion.py
+++ b/diffusion.py
@@ -92,8 +92,8 @@ class Diffuser:
 
         x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
 
-        # with torch.amp.autocast("cuda"):
-        logits_hat_x_0 = model(x_t_with_mask)
+        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+            logits_hat_x_0 = model(x_t_with_mask)
 
         return logits_hat_x_0
 
@@ -117,8 +117,8 @@ class Diffuser:
 
         for it in range(self.nb_iterations):
             x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
-            # with torch.amp.autocast("cuda"):
-            logits = model(x_t_with_mask)
+            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+                logits = model(x_t_with_mask)
             # logits[:, :, quiz_machine.problem.nb_colors :] = float("-inf")
             dist = torch.distributions.categorical.Categorical(logits=logits)
 
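The change above wraps the model's forward pass in bfloat16 autocast. As a rough, self-contained sketch of that pattern (toy model and inputs for illustration only, not the repository's; requires a CUDA device):

    import torch
    import torch.nn as nn

    model = nn.Linear(64, 10).cuda()
    x = torch.randn(8, 64, device="cuda")

    # Inside the context, matmul-heavy ops run in bfloat16 while the
    # parameters themselves stay in float32.
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        logits = model(x)

    print(logits.dtype)  # torch.bfloat16

Unlike float16, bfloat16 keeps the float32 exponent range, so no loss scaling is strictly required for the forward pass alone.
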
diff --git a/main.py b/main.py
index 534bab9..19a3fee 100755
--- a/main.py
+++ b/main.py
@@ -27,6 +27,8 @@ import threading, subprocess
 
 # torch.set_float32_matmul_precision("high")
 
+# torch.set_default_dtype(torch.bfloat16)
+
 import diffusion
 
 ######################################################################
@@ -565,7 +567,7 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
 
     nb_train_samples, acc_train_loss = 0, 0.0
 
-    # scaler = torch.amp.GradScaler("cuda")
+    scaler = torch.amp.GradScaler("cuda")
 
     for x_0, mask_generate in ae_batches(
         quiz_machine,
@@ -581,29 +583,29 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
         if nb_train_samples % args.batch_size == 0:
             model.optimizer.zero_grad()
 
-        # with torch.amp.autocast("cuda"):
-        logits = diffuser.logits_hat_x_0_from_random_iteration(
-            model=model,
-            x_0=x_0,
-            mask_generate=mask_generate,
-            prompt_noise=args.prompt_noise,
-        )
+        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+            logits = diffuser.logits_hat_x_0_from_random_iteration(
+                model=model,
+                x_0=x_0,
+                mask_generate=mask_generate,
+                prompt_noise=args.prompt_noise,
+            )
 
         loss = NTC_masked_cross_entropy(logits, x_0, mask_generate)
         acc_train_loss += loss.item() * x_0.size(0)
         nb_train_samples += x_0.size(0)
 
-        loss.backward()
+        # loss.backward()

-        if nb_train_samples % args.batch_size == 0:
-            model.optimizer.step()
+        # if nb_train_samples % args.batch_size == 0:
+        # model.optimizer.step()

-        # scaler.scale(loss).backward()
+        scaler.scale(loss).backward()

-        # if nb_train_samples % args.batch_size == 0:
-        # scaler.step(model.optimizer)
+        if nb_train_samples % args.batch_size == 0:
+            scaler.step(model.optimizer)

-        # scaler.update()
+        scaler.update()
 
     log_string(
         f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}"
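
The training loop above now routes the backward pass and optimizer step through torch.amp.GradScaler instead of calling loss.backward() and model.optimizer.step() directly. A minimal sketch of that scaler-based step with a toy model and optimizer (illustrative names, not the repository's); loss scaling mainly matters for float16 and is generally unnecessary, though harmless, with bfloat16:

    import torch
    import torch.nn as nn

    model = nn.Linear(64, 10).cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scaler = torch.amp.GradScaler("cuda")

    x = torch.randn(32, 64, device="cuda")
    y = torch.randint(0, 10, (32,), device="cuda")

    optimizer.zero_grad()
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        logits = model(x)
        loss = nn.functional.cross_entropy(logits, y)

    scaler.scale(loss).backward()  # backprop on the (possibly scaled) loss
    scaler.step(optimizer)         # unscales gradients, steps only if they are finite
    scaler.update()                # adjusts the scale factor for the next iteration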