Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 31 Aug 2024 08:01:10 +0000 (10:01 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 31 Aug 2024 08:01:10 +0000 (10:01 +0200)
main.py

diff --git a/main.py b/main.py
index 43a8774..c0f9e57 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -862,7 +862,7 @@ class FunctionalAE(nn.Module):
         self.positional_encoding = VaswaniPositionalEncoding(len_max=1e5)
 
         def trunk(nb, bottom=True):
-            trunk_blocks = []
+            trunk_blocks = [VaswaniPositionalEncoding(len_max=1e5)]
 
             la = [
                 QKVAttention(
@@ -872,7 +872,6 @@ class FunctionalAE(nn.Module):
                     nb_heads=nb_heads,
                     attention_dropout=dropout,
                 ),
-                VaswaniPositionalEncoding(len_max=1e5),
             ]
 
             # if not bottom:
@@ -929,6 +928,9 @@ class FunctionalAE(nn.Module):
         theta_A = self.phi(cat(ft, x_A, x_f_A)).x[:, :K, :]
         theta_B = self.phi(cat(ft, x_B, x_f_B)).x[:, :K, :]
 
+        # if self.hook_theta is not None:
+        # self.hook_theta(theta_A, theta_B)
+
         hat_f_A = self.psi(cat(x_A, theta_B)).x[:, :L]
         hat_f_B = self.psi(cat(x_B, theta_A)).x[:, :L]
 
@@ -1171,7 +1173,7 @@ def run_ae_test(model, other_models, quiz_machine, n_epoch, local_device=main_de
 
         model.test_accuracy = nb_correct / nb_total
 
-        for f, record in [("prediction", record_d), ("generative", record_nd)]:
+        for f, record in [("prediction", record_d), ("generation", record_nd)]:
             filename = f"culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
             result, predicted_parts, correct_parts = (
                 torch.cat([x[i] for x in record])[:128] for i in [0, 1, 2]
@@ -1194,6 +1196,31 @@ def run_ae_test(model, other_models, quiz_machine, n_epoch, local_device=main_de
             )
             log_string(f"wrote {filename}")
 
+        # Prediction with functional perturbations
+
+        # input, mask_generate, mask_loss = next(
+        # ae_batches(
+        # quiz_machine,
+        # [
+        # (
+        # ("A", "f_A", "B", "f_B"),
+        # (0, 0, 0, 1),
+        # (0, 0, 1, 0),
+        # (0, 0, 0, 1),
+        # ),
+        # ],
+        # local_device,
+        # desc=None,
+        # )
+        # )
+        # targets = input.clone()
+        # p = torch.rand(4,model.f_tokens.size(1)).sort(dim=1).indices
+        # def change_theta(theta_A, theta_B):
+        # theta
+        # result = ae_generate(
+        # model, (1 - mask_generate) * input, mask_generate, noise_proba
+        # )
+
 
 ######################################################################
 
@@ -1371,7 +1398,9 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     duration = time.perf_counter() - start_time
     str_duration = ""
     if duration >= 60:
-        str_duration += f"{int(duration//60)}min"
-        duration = duration % 60
-    str_duration += f"{duration:.01f}s"
-    log_string(f"epoch_duration {str_duration}")
+        str_duration += f"{int(duration)//60}min"
+    str_duration += f"{int(duration)%60}s"
+    str_next = (
+        datetime.datetime.now() + datetime.timedelta(seconds=duration)
+    ).strftime("%H:%M:%S")
+    log_string(f"epoch_duration {str_duration} next_finish {str_next}")