From: François Fleuret Date: Wed, 10 Jul 2024 18:09:50 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=050976a525fee2d3b824350a3058ab7299a2bd3d;p=culture.git Update. --- diff --git a/main.py b/main.py index 0c193f7..1ef01e9 100755 --- a/main.py +++ b/main.py @@ -16,6 +16,8 @@ import ffutils import mygpt import sky, grids, quiz_machine +import threading + # world quizzes vs. culture quizzes ###################################################################### @@ -38,7 +40,7 @@ parser.add_argument("--result_dir", type=str, default=None) parser.add_argument("--seed", type=int, default=0) -parser.add_argument("--max_percents_of_test_in_train", type=int, default=1) +parser.add_argument("--max_percents_of_test_in_train", type=int, default=-1) ######################################## @@ -78,6 +80,8 @@ parser.add_argument("--problem", type=str, default="grids") parser.add_argument("--nb_threads", type=int, default=1) +parser.add_argument("--nb_gpus", type=int, default=1) + parser.add_argument("--nb_gpts", type=int, default=5) parser.add_argument("--min_to_validate", type=int, default=None) @@ -238,14 +242,14 @@ if args.problem == "sky": nb_birds=args.sky_nb_birds, nb_iterations=args.sky_nb_iterations, speed=args.sky_speed, - max_nb_cached_chunks=args.nb_train_samples // 100, + max_nb_cached_chunks=args.nb_gpus * args.nb_train_samples // 100, chunk_size=100, nb_threads=args.nb_threads, ) back_accuracy = False elif args.problem == "grids": problem = grids.Grids( - max_nb_cached_chunks=args.nb_train_samples // 100, + max_nb_cached_chunks=args.nb_gpus * args.nb_train_samples // 100, chunk_size=100, nb_threads=args.nb_threads, ) @@ -273,50 +277,23 @@ vocabulary_size = quiz_machine.vocabulary_size() log_string(f"vocabulary_size {vocabulary_size}") ###################################################################### -############################## - - -def one_epoch(model, quiz_machine): - optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) - - model.train() - - nb_train_samples, acc_train_loss = 0, 0.0 - - for input in quiz_machine.batches(model, split="train"): - input = input.to(device) - - if nb_train_samples % args.batch_size == 0: - optimizer.zero_grad() - - output = model(mygpt.BracketedSequence(input)).x - loss = F.cross_entropy(output.transpose(1, 2), input) - acc_train_loss += loss.item() * input.size(0) - - nb_train_samples += input.size(0) - - loss.backward() - - if nb_train_samples % args.batch_size == 0: - optimizer.step() - - train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) - - log_string(f"train_perplexity {n_epoch} {train_perplexity}") ###################################################################### -def run_tests(model, quiz_machine, deterministic_synthesis): +def run_tests(model, quiz_machine, deterministic_synthesis, local_device=None): + if local_device is None: + local_device = device + with torch.autograd.no_grad(): - model.eval() + model.eval().to(local_device) nb_test_samples, acc_test_loss = 0, 0.0 nb_samples_accumulated = 0 for input in quiz_machine.batches(model, split="test"): - input = input.to(device) + input = input.to(local_device) bs = model(mygpt.BracketedSequence(input)) output = bs.x @@ -339,6 +316,42 @@ def run_tests(model, quiz_machine, deterministic_synthesis): ) +def one_epoch(model, quiz_machine, local_device=None): + if local_device is None: + local_device = device + + optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) + + model.to(local_device).train() + + nb_train_samples, acc_train_loss = 0, 0.0 + + for input in quiz_machine.batches(model, split="train"): + input = input.to(local_device) + + if nb_train_samples % args.batch_size == 0: + optimizer.zero_grad() + + output = model(mygpt.BracketedSequence(input)).x + loss = F.cross_entropy(output.transpose(1, 2), input) + acc_train_loss += loss.item() * input.size(0) + + nb_train_samples += input.size(0) + + loss.backward() + + if nb_train_samples % args.batch_size == 0: + optimizer.step() + + train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) + + log_string(f"train_perplexity {n_epoch} {train_perplexity}") + + run_tests(model, quiz_machine, deterministic_synthesis=False) + + model.TRAINING_LOCK.release() + + ###################################################################### @@ -548,6 +561,7 @@ for k in range(args.nb_gpts): model.main_test_accuracy = 0.0 model.id = k + model.TRAINING_LOCK = threading.Lock() model.train_w_quizzes = quiz_machine.generate_token_sequences( args.nb_train_samples @@ -640,23 +654,24 @@ for n_epoch in range(args.nb_epochs): ################################################## # Select, improve, and eval the worst model - weakest_model = min(models, key=lambda m: float(m.main_test_accuracy)) + ranked_models = sorted(models, key=lambda m: float(m.main_test_accuracy)) - log_string( - f"training model {weakest_model.id} main_test_accuracy {weakest_model.main_test_accuracy}" - ) + weakest_models = ranked_models[: args.nb_gpus] - one_epoch(weakest_model, quiz_machine) + for gpu_id, model in enumerate(weakest_models): + model.TRAINING_LOCK.acquire() - log_string( - f"train_set_composition w_quizzes {quiz_machine.nb_batch_w_quizzes} c_quizzes {quiz_machine.nb_batch_c_quizzes}" - ) + log_string( + f"training model {model.id} main_test_accuracy {model.main_test_accuracy}" + ) - run_tests(weakest_model, quiz_machine, deterministic_synthesis=False) + threading.Thread( + target=one_epoch, daemon=True, args=(model, quiz_machine, f"cuda:{gpu_id}") + ).start() - log_string( - f"test_set_composition w_quizzes {quiz_machine.nb_batch_w_quizzes} c_quizzes {quiz_machine.nb_batch_c_quizzes}" - ) + for model in weakest_models: + model.TRAINING_LOCK.acquire() + model.TRAINING_LOCK.release() ################################################## # Replace a fraction of the w_quizzes with fresh ones @@ -667,7 +682,8 @@ for n_epoch in range(args.nb_epochs): # Renew entirely the train set - quiz_machine.renew_w_quizzes(model, args.nb_train_samples) + for model in weakest_models: + quiz_machine.renew_w_quizzes(model, args.nb_train_samples) ################################################## # If all the models are good enough, generate new quizzes and @@ -681,8 +697,4 @@ for n_epoch in range(args.nb_epochs): nb_for_test=nb_new_c_quizzes_for_test, ) - for model in models: - run_tests(model, quiz_machine, deterministic_synthesis=False) - - ###################################################################### diff --git a/quiz_machine.py b/quiz_machine.py index 34c09a7..1f1046d 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -15,6 +15,8 @@ from torch.nn import functional as F import mygpt from mygpt import BracketedSequence +import threading + ###################################################################### # ar_mask is a tensor with 0s and 1s, of same shape as input, with @@ -235,22 +237,10 @@ class QuizMachine: self.prompt_len = None self.answer_len = None - # self.train_w_quizzes = self.generate_token_sequences(nb_train_samples) - # self.reverse_random_half_in_place(self.train_w_quizzes) - - # self.test_w_quizzes = self.generate_token_sequences(nb_test_samples).to(device) - # self.reverse_random_half_in_place(self.test_w_quizzes) - + self.LOCK_C_QUIZZES = threading.Lock() self.train_c_quizzes = [] self.test_c_quizzes = [] - # if result_dir is not None: - # self.save_quizzes( - # result_dir, - # "culture_w_quizzes", - # self.train_w_quizzes[:72], - # ) - def save_quizzes( self, result_dir, @@ -292,32 +282,34 @@ class QuizMachine: def batches(self, model, split="train", desc=None): assert split in {"train", "test"} - if split == "train": - w_quizzes = model.train_w_quizzes - c_quizzes = self.train_c_quizzes - else: - w_quizzes = model.test_w_quizzes - c_quizzes = self.test_c_quizzes - if len(c_quizzes) > 0: - c_quizzes = torch.cat(c_quizzes, dim=0) - if c_quizzes.size(0) > w_quizzes.size(0) // 2: - i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2] - c_quizzes = c_quizzes[i] + with self.LOCK_C_QUIZZES: + if split == "train": + w_quizzes = model.train_w_quizzes + c_quizzes = self.train_c_quizzes + else: + w_quizzes = model.test_w_quizzes + c_quizzes = self.test_c_quizzes + + if len(c_quizzes) > 0: + c_quizzes = torch.cat(c_quizzes, dim=0) + if c_quizzes.size(0) > w_quizzes.size(0) // 2: + i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2] + c_quizzes = c_quizzes[i] + + i = torch.randperm(w_quizzes.size(0))[ + : w_quizzes.size(0) - c_quizzes.size(0) + ] + w_quizzes = w_quizzes[i] - i = torch.randperm(w_quizzes.size(0))[ - : w_quizzes.size(0) - c_quizzes.size(0) - ] - w_quizzes = w_quizzes[i] + self.nb_batch_w_quizzes = w_quizzes.size(0) + self.nb_batch_c_quizzes = c_quizzes.size(0) - self.nb_batch_w_quizzes = w_quizzes.size(0) - self.nb_batch_c_quizzes = c_quizzes.size(0) - - input = torch.cat([w_quizzes, c_quizzes], dim=0) - else: - input = w_quizzes - self.nb_batch_w_quizzes = w_quizzes.size(0) - self.nb_batch_c_quizzes = 0 + input = torch.cat([w_quizzes, c_quizzes], dim=0) + else: + input = w_quizzes + self.nb_batch_w_quizzes = w_quizzes.size(0) + self.nb_batch_c_quizzes = 0 # Shuffle input = input[torch.randperm(input.size(0))] @@ -417,10 +409,11 @@ class QuizMachine: ###################################################################### def store_c_quizzes(self, new_c_quizzes, for_train=True): - if for_train: - self.train_c_quizzes.append(new_c_quizzes) - else: - self.test_c_quizzes.append(new_c_quizzes) + with self.LOCK_C_QUIZZES: + if for_train: + self.train_c_quizzes.append(new_c_quizzes) + else: + self.test_c_quizzes.append(new_c_quizzes) ######################################################################