From: François Fleuret Date: Thu, 19 Sep 2024 14:16:03 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=b480e061a9753ad9cfbe82cdd21636587b96566d;p=culture.git Update. --- diff --git a/main.py b/main.py index c08b04d..7c8c836 100755 --- a/main.py +++ b/main.py @@ -249,16 +249,6 @@ assert args.nb_test_samples % args.batch_size == 0 ###################################################################### -problem = grids.Grids( - max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100, - chunk_size=100, - nb_threads=args.nb_threads, - tasks=args.grids_world_tasks, -) - -if not args.resume: - problem.save_some_examples(args.result_dir) - def pure_noise(nb, device): r = problem.pure_noise(nb, device) @@ -308,14 +298,6 @@ def quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1): ###################################################################### -log_string(f"main_device {main_device} gpus {[ str(g) for g in gpus]}") - -vocabulary_size = problem.vocabulary_size() - -log_string(f"vocabulary_size {vocabulary_size}") - -###################################################################### - def optimizer_to(optim, device): """Move the optimizer optim to the device""" @@ -637,30 +619,6 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device): ###################################################################### -models = [] - -for i in range(args.nb_models): - # model = attae.FunctionalAttentionAE( - model = attae.AttentionAE( - vocabulary_size=vocabulary_size * 2, - dim_model=args.dim_model, - dim_keys=args.dim_keys, - dim_hidden=args.dim_hidden, - nb_heads=args.nb_heads, - nb_blocks=args.nb_blocks, - dropout=args.dropout, - ) - - # model = torch.compile(model) - - model.id = i - model.test_accuracy = 0.0 - model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) - - models.append(model) - -###################################################################### - def evaluate_quizzes(quizzes, models, local_device): nb_correct, nb_wrong = 0, 0 @@ -768,6 +726,62 @@ def generate_c_quizzes(models, nb_to_generate, local_device=main_device): ###################################################################### +def multithread_execution(fun, arguments): + # Single instance, no thread + if len(arguments) == 1: + return fun(*(arguments[0])) + + records, threads = [], [] + + def threadable_fun(*args): + r = fun(*args) + if type(r) is not tuple: + r = (r,) + records.append(r) + + for args in arguments: + # To get a different sequence between threads + # log_string(f"dummy_rand {torch.rand(1)}") + torch.rand(1) + t = threading.Thread(target=threadable_fun, daemon=True, args=args) + threads.append(t) + t.start() + + for t in threads: + t.join() + + if records[0] == (None,): + return + else: + return [ + torch.cat([x[k] for x in records], dim=0) for k in range(len(records[0])) + ] + + +###################################################################### + + +def save_models(models, suffix=""): + if suffix != "": + suffix = "_" + suffix + + for model in models: + filename = f"ae_{model.id:03d}{suffix}.pth" + torch.save( + { + "state_dict": model.state_dict(), + "optimizer_state_dict": model.optimizer.state_dict(), + "test_accuracy": model.test_accuracy, + }, + os.path.join(args.result_dir, filename), + ) + + log_string(f"wrote ae_*{suffix}.pth") + + +###################################################################### + + def save_quiz_image(models, c_quizzes, filename, local_device=main_device): c_quizzes = c_quizzes.to(local_device) @@ -793,6 +807,49 @@ def save_quiz_image(models, c_quizzes, filename, local_device=main_device): ###################################################################### +problem = grids.Grids( + max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100, + chunk_size=100, + nb_threads=args.nb_threads, + tasks=args.grids_world_tasks, +) + +if not args.resume: + problem.save_some_examples(args.result_dir) + + +log_string(f"main_device {main_device} gpus {[ str(g) for g in gpus]}") + +vocabulary_size = problem.vocabulary_size() + +log_string(f"vocabulary_size {vocabulary_size}") + +###################################################################### + +models = [] + +for i in range(args.nb_models): + # model = attae.FunctionalAttentionAE( + model = attae.AttentionAE( + vocabulary_size=vocabulary_size * 2, + dim_model=args.dim_model, + dim_keys=args.dim_keys, + dim_hidden=args.dim_hidden, + nb_heads=args.nb_heads, + nb_blocks=args.nb_blocks, + dropout=args.dropout, + ) + + # model = torch.compile(model) + + model.id = i + model.test_accuracy = 0.0 + model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) + + models.append(model) + +###################################################################### + current_epoch = 0 if args.resume: @@ -837,62 +894,6 @@ c_quizzes = None ###################################################################### - -def multithread_execution(fun, arguments): - # Single instance, no thread - if len(arguments) == 1: - return fun(*(arguments[0])) - - records, threads = [], [] - - def threadable_fun(*args): - r = fun(*args) - if type(r) is not tuple: - r = (r,) - records.append(r) - - for args in arguments: - # To get a different sequence between threads - # log_string(f"dummy_rand {torch.rand(1)}") - torch.rand(1) - t = threading.Thread(target=threadable_fun, daemon=True, args=args) - threads.append(t) - t.start() - - for t in threads: - t.join() - - if records[0] == (None,): - return - else: - return [ - torch.cat([x[k] for x in records], dim=0) for k in range(len(records[0])) - ] - - -###################################################################### - - -def save_models(models, suffix=""): - if suffix != "": - suffix = "_" + suffix - - for model in models: - filename = f"ae_{model.id:03d}{suffix}.pth" - torch.save( - { - "state_dict": model.state_dict(), - "optimizer_state_dict": model.optimizer.state_dict(), - "test_accuracy": model.test_accuracy, - }, - os.path.join(args.result_dir, filename), - ) - - log_string(f"wrote ae_*{suffix}.pth") - - -###################################################################### - for n_epoch in range(current_epoch, args.nb_epochs): start_time = time.perf_counter()