From: François Fleuret Date: Sun, 1 Sep 2024 19:54:45 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=9b03df47520cff5c5da7f0655861a64ffc9c0e1a;p=culture.git Update. --- diff --git a/main.py b/main.py index 9801702..120e19c 100755 --- a/main.py +++ b/main.py @@ -1134,7 +1134,7 @@ def targets_and_prediction(model, input, mask_generate): return targets, logits -def run_ae_test(model, quiz_machine, n_epoch, local_device=main_device): +def run_ae_test(model, quiz_machine, n_epoch, c_quizzes=None, local_device=main_device): with torch.autograd.no_grad(): model.eval().to(local_device) @@ -1147,7 +1147,8 @@ def run_ae_test(model, quiz_machine, n_epoch, local_device=main_device): args.nb_test_samples, data_structures, local_device, - "test", + c_quizzes=c_quizzes, + desc="test", ): targets, logits = targets_and_prediction(model, input, mask_generate) loss = NTC_masked_cross_entropy(logits, targets, mask_loss) @@ -1167,6 +1168,7 @@ def run_ae_test(model, quiz_machine, n_epoch, local_device=main_device): args.nb_test_samples, data_structures, local_device, + c_quizzes, "test", ): targets = input.clone() @@ -1282,7 +1284,7 @@ def one_ae_epoch( f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}" ) - run_ae_test(model, other_models, quiz_machine, n_epoch, local_device=local_device) + run_ae_test(model, quiz_machine, n_epoch, c_quizzes, local_device=local_device) ######################################################################