From f1c942d6464644b4490e9450bc458a7a31cabb3e Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 14 Sep 2024 21:17:18 +0200 Subject: [PATCH] Update. --- main.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/main.py b/main.py index 62cbd2f..b751374 100755 --- a/main.py +++ b/main.py @@ -1362,10 +1362,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): if n_epoch == 0: args = args[:1] - c_quizzes, agreements = multithread_execution( - generate_ae_c_quizzes, - args, - ) + c_quizzes, agreements = multithread_execution(generate_ae_c_quizzes, args) save_c_quizzes_with_scores( models, @@ -1398,13 +1395,17 @@ for n_epoch in range(current_epoch, args.nb_epochs): # None if c_quizzes is None else c_quizzes[agreements[:, model.id]], - multithread_execution( - one_ae_epoch, - [ - (model, quiz_machine, n_epoch, c_quizzes, gpu) - for model, gpu in zip(weakest_models, gpus) - ], - ) + args = [ + (model, quiz_machine, n_epoch, c_quizzes, gpu) + for model, gpu in zip(weakest_models, gpus) + ] + + # Ugly hack: Only one thread during the first epoch so that + # compilation of the model does not explode + if n_epoch == 0: + args = args[:1] + + multithread_execution(one_ae_epoch, args) # -------------------------------------------------------------------- -- 2.39.5