From 42b3fcd36bf580aaa2f005625def3adbe4cc2794 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 14 Aug 2024 13:04:54 +0200 Subject: [PATCH] Update. --- main.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/main.py b/main.py index 0b9a86e..0bbcc6b 100755 --- a/main.py +++ b/main.py @@ -528,7 +528,7 @@ def model_proba_solutions(m, quizzes): def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test): nb_validated, nb_to_validate = 0, (nb_for_train + nb_for_test) * len(models) - nb_generated, nb_to_generate_per_iteration = 0, nb_to_validate + nb_generated, nb_to_generate_per_iteration = 0, nb_to_validate // 10 start_time = time.perf_counter() @@ -759,11 +759,10 @@ class Thinker(nn.Module): if args.test == "func": - train_input = quiz_machine.problem.generate_w_quizzes(args.nb_train_samples) test_input = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples) - L = train_input.size(1) // 4 - f_len = 25 + L = test_input.size(1) // 4 + f_len = 50 model = Thinker( vocabulary_size=vocabulary_size, @@ -772,7 +771,7 @@ if args.test == "func": dim_hidden=args.dim_hidden, nb_heads=args.nb_heads, nb_blocks=args.nb_blocks, - f_len=20, + f_len=f_len, dropout=args.dropout, ).to(main_device) @@ -781,6 +780,8 @@ if args.test == "func": for n_epoch in range(args.nb_epochs): model.train() + train_input = quiz_machine.problem.generate_w_quizzes(args.nb_train_samples) + nb_train_samples, acc_train_loss = 0, 0.0 for input in tqdm.tqdm( -- 2.39.5