From: François Fleuret Date: Fri, 13 Sep 2024 07:57:51 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=d3a114008c4767136502784ac08d5ed6d4dde39f;p=culture.git Update. --- diff --git a/attae.py b/attae.py index 069772b..bc90ed0 100755 --- a/attae.py +++ b/attae.py @@ -45,19 +45,17 @@ class WithResidual(nn.Module): ###################################################################### -class vanilla_attention(q, k, v): +def vanilla_attention(q, k, v): a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3)) a = a.softmax(dim=3) y = torch.einsum("nhts,nhsd->nhtd", a, v) - - # y = flex_attention(q, k, v, score_mod=noop) - y = torch.einsum("nhtd,hdc->ntc", y, self.w_o) - return y -vanilla_attention = torch.compille(vanilla_attention) +vanilla_attention = torch.compile(vanilla_attention) + +# y = flex_attention(q, k, v, score_mod=noop) class MHAttention(nn.Module): @@ -93,7 +91,7 @@ class MHAttention(nn.Module): def noop(score, b, h, q_idx, kv_idx): return score - y = vanilla_attention(q, k, v, score_mod=noop) + y = vanilla_attention(q, k, v) # y = flex_attention(q, k, v, score_mod=noop) y = torch.einsum("nhtd,hdc->ntc", y, self.w_o) @@ -163,7 +161,6 @@ class AttentionAE(nn.Module): m.weight.fill_(1.0) def forward(self, x): - x = 2 * x[:, :, 0] + x[:, :, 1] x = self.embedding(x) x = self.positional_encoding(x) x = self.trunk(x) diff --git a/main.py b/main.py index 63cd377..0fea318 100755 --- a/main.py +++ b/main.py @@ -55,7 +55,7 @@ parser.add_argument("--inference_batch_size", type=int, default=25) parser.add_argument("--nb_train_samples", type=int, default=25000) -parser.add_argument("--nb_test_samples", type=int, default=1000) +parser.add_argument("--nb_test_samples", type=int, default=10000) parser.add_argument("--nb_train_alien_samples", type=int, default=0) @@ -1388,9 +1388,26 @@ def multithread_execution(fun, arguments): ###################################################################### -for n_epoch in range(current_epoch, args.nb_epochs): - start_time = time.perf_counter() +def save_models(models, suffix=""): + if suffix is not "": + suffix = "_" + suffix + for model in models: + filename = f"ae_{model.id:03d}{suffix}.pth" + torch.save( + { + "state_dict": model.state_dict(), + "optimizer_state_dict": model.optimizer.state_dict(), + "test_accuracy": model.test_accuracy, + }, + os.path.join(args.result_dir, filename), + ) + log_string(f"wrote {filename}") + + +###################################################################### + +for n_epoch in range(current_epoch, args.nb_epochs): state = { "current_epoch": n_epoch, "c_quizzes": c_quizzes, @@ -1414,46 +1431,37 @@ for n_epoch in range(current_epoch, args.nb_epochs): and time_train >= time_c_quizzes ): if c_quizzes is None: - for model in models: - filename = f"ae_{model.id:03d}_naive.pth" - torch.save( - { - "state_dict": model.state_dict(), - "optimizer_state_dict": model.optimizer.state_dict(), - "test_accuracy": model.test_accuracy, - }, - os.path.join(args.result_dir, filename), - ) - log_string(f"wrote {filename}") - - # -------------------------------------------------------------------- + save_models(models, "naive") last_n_epoch_c_quizzes = n_epoch nb_gpus = len(gpus) nb_c_quizzes_to_generate = (args.nb_c_quizzes + nb_gpus - 1) // nb_gpus - # -------------------------------------------------------------------- + start_time = time.perf_counter() c_quizzes, agreements = multithread_execution( generate_ae_c_quizzes, [(models, nb_c_quizzes_to_generate, gpu) for gpu in gpus], ) - # -------------------------------------------------------------------- - - filename = f"culture_c_quiz_{n_epoch:04d}.png" save_c_quizzes_with_scores( - models, c_quizzes[:256], filename, solvable_only=False + models, + c_quizzes[:256], + f"culture_c_quiz_{n_epoch:04d}.png", + solvable_only=False, ) - filename = f"culture_c_quiz_{n_epoch:04d}_solvable.png" save_c_quizzes_with_scores( - models, c_quizzes[:256], filename, solvable_only=True + models, + c_quizzes[:256], + f"culture_c_quiz_{n_epoch:04d}_solvable.png", + solvable_only=True, ) log_string(f"generated_c_quizzes {c_quizzes.size()=}") time_train = 0 + for model in models: model.test_accuracy = 0 @@ -1467,6 +1475,8 @@ for n_epoch in range(current_epoch, args.nb_epochs): ranked_models = sorted(models, key=lambda m: float(m.test_accuracy)) weakest_models = ranked_models[: len(gpus)] + start_time = time.perf_counter() + multithread_execution( one_ae_epoch, [ @@ -1485,19 +1495,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): # -------------------------------------------------------------------- - for model in models: - filename = f"ae_{model.id:03d}.pth" - torch.save( - { - "state_dict": model.state_dict(), - "optimizer_state_dict": model.optimizer.state_dict(), - "test_accuracy": model.test_accuracy, - }, - os.path.join(args.result_dir, filename), - ) - log_string(f"wrote {filename}") - - # -------------------------------------------------------------------- + save_models(models) duration = time.perf_counter() - start_time str_duration = ""