From d34deee897329af00296a656bb3bb88617d814e5 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 12 Sep 2024 14:36:08 +0200 Subject: [PATCH] Update. --- main.py | 63 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/main.py b/main.py index 3753e9b..f6f37d6 100755 --- a/main.py +++ b/main.py @@ -634,6 +634,7 @@ def sample_x_t_minus_1_given_x_0_x_t(x_0, x_t): return x_t_minus_1 +###################################################################### # Non-uniform transitions, to be fixed? @@ -1350,6 +1351,68 @@ c_quizzes = None time_c_quizzes = 0 time_train = 0 +###################################################################### + + +def multithread_execution(fun, arguments): + if len(arguments) == 1: + return fun(*(arguments[0])) + + records, threads = [], [] + + def threadable_fun(*args): + records.append(fun(*args)) + + for args in arguments: + t = threading.Thread(target=threadable_fun, daemon=True, args=args) + + # To get a different sequence between threads + log_string(f"dummy_rand {torch.rand(1)}") + threads.append(t) + t.start() + + for t in threads: + t.join() + + if records == []: + return + else: + return [ + torch.cat([x[k] for x in records], dim=0) for k in range(len(records[0])) + ] + + +# ----- test + +# nb_gpus = len(gpus) +# nb_c_quizzes_to_generate = (args.nb_c_quizzes + nb_gpus - 1) // nb_gpus + +# c_quizzes, agreements = multithread_execution( +# generate_ae_c_quizzes, +# [(models, nb_c_quizzes_to_generate, gpu) for gpu in gpus], +# ) + +# ranked_models = sorted(models, key=lambda m: float(m.test_accuracy)) +# weakest_models = ranked_models[: len(gpus)] + +# n_epoch = 14 + +# multithread_execution( +# one_ae_epoch, +# [ +# ( +# model, +# quiz_machine, +# n_epoch, +# None if c_quizzes is None else c_quizzes[agreements[:, model.id]], +# gpu, +# ) +# for model, gpu in zip(weakest_models, gpus) +# ], +# ) + +###################################################################### + for n_epoch in range(current_epoch, args.nb_epochs): start_time = time.perf_counter() -- 2.39.5