From 870549ee2031656af3a66a235d4b1413cc30c1f4 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Sat, 31 Aug 2024 10:01:10 +0200
Subject: [PATCH] Update.

---
 main.py | 43 ++++++++++++++++++++++++++++++++++++-------
 1 file changed, 36 insertions(+), 7 deletions(-)

diff --git a/main.py b/main.py
index 43a8774..c0f9e57 100755
--- a/main.py
+++ b/main.py
@@ -862,7 +862,7 @@ class FunctionalAE(nn.Module):
         self.positional_encoding = VaswaniPositionalEncoding(len_max=1e5)
 
         def trunk(nb, bottom=True):
-            trunk_blocks = []
+            trunk_blocks = [VaswaniPositionalEncoding(len_max=1e5)]
 
             la = [
                 QKVAttention(
@@ -872,7 +872,6 @@ class FunctionalAE(nn.Module):
                     nb_heads=nb_heads,
                     attention_dropout=dropout,
                 ),
-                VaswaniPositionalEncoding(len_max=1e5),
             ]
 
             # if not bottom:
@@ -929,6 +928,9 @@ class FunctionalAE(nn.Module):
         theta_A = self.phi(cat(ft, x_A, x_f_A)).x[:, :K, :]
         theta_B = self.phi(cat(ft, x_B, x_f_B)).x[:, :K, :]
 
+        # if self.hook_theta is not None:
+        #     self.hook_theta(theta_A, theta_B)
+
         hat_f_A = self.psi(cat(x_A, theta_B)).x[:, :L]
         hat_f_B = self.psi(cat(x_B, theta_A)).x[:, :L]
 
@@ -1171,7 +1173,7 @@ def run_ae_test(model, other_models, quiz_machine, n_epoch, local_device=main_de
     model.test_accuracy = nb_correct / nb_total
 
-    for f, record in [("prediction", record_d), ("generative", record_nd)]:
+    for f, record in [("prediction", record_d), ("generation", record_nd)]:
         filename = f"culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
         result, predicted_parts, correct_parts = (
             torch.cat([x[i] for x in record])[:128] for i in [0, 1, 2]
         )
@@ -1194,6 +1196,31 @@ def run_ae_test(model, other_models, quiz_machine, n_epoch, local_device=main_de
         )
         log_string(f"wrote {filename}")
 
+    # Prediction with functional perturbations
+
+    # input, mask_generate, mask_loss = next(
+    #     ae_batches(
+    #         quiz_machine,
+    #         [
+    #             (
+    #                 ("A", "f_A", "B", "f_B"),
+    #                 (0, 0, 0, 1),
+    #                 (0, 0, 1, 0),
+    #                 (0, 0, 0, 1),
+    #             ),
+    #         ],
+    #         local_device,
+    #         desc=None,
+    #     )
+    # )
+    # targets = input.clone()
+    # p = torch.rand(4,model.f_tokens.size(1)).sort(dim=1).indices
+    # def change_theta(theta_A, theta_B):
+    #     theta
+    # result = ae_generate(
+    #     model, (1 - mask_generate) * input, mask_generate, noise_proba
+    # )
+
 
 ######################################################################
 
@@ -1371,7 +1398,9 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     duration = time.perf_counter() - start_time
     str_duration = ""
     if duration >= 60:
-        str_duration += f"{int(duration//60)}min"
-        duration = duration % 60
-    str_duration += f"{duration:.01f}s"
-    log_string(f"epoch_duration {str_duration}")
+        str_duration += f"{int(duration)//60}min"
+    str_duration += f"{int(duration)%60}s"
+    str_next = (
+        datetime.datetime.now() + datetime.timedelta(seconds=duration)
+    ).strftime("%H:%M:%S")
+    log_string(f"epoch_duration {str_duration} next_finish {str_next}")
-- 
2.39.5