From 1e7e3d3877ae98737852b78f35a98030dcc0701f Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 11 Jul 2024 16:15:03 +0200 Subject: [PATCH] Update. --- main.py | 392 ++++++++++++++++++++++++++--------------------------- problem.py | 3 - 2 files changed, 194 insertions(+), 201 deletions(-) diff --git a/main.py b/main.py index 0a266a8..73e7ca2 100755 --- a/main.py +++ b/main.py @@ -16,182 +16,12 @@ import ffutils import mygpt import sky, grids, quiz_machine -import torch.multiprocessing as mp - -# mp.set_start_method('spawn') +import threading # world quizzes vs. culture quizzes ###################################################################### - -def log_string(s): - t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime()) - - if log_file is not None: - log_file.write(t + s + "\n") - log_file.flush() - - print(t + s) - sys.stdout.flush() - - -###################################################################### - - -def run_tests(model, quiz_machine, deterministic_synthesis, local_device=None): - if local_device is None: - local_device = device - - with torch.autograd.no_grad(): - model.eval().to(local_device) - - nb_test_samples, acc_test_loss = 0, 0.0 - nb_samples_accumulated = 0 - - for input in quiz_machine.batches(model, split="test"): - input = input.to(local_device) - - bs = model(mygpt.BracketedSequence(input)) - output = bs.x - - loss = F.cross_entropy(output.transpose(1, 2), input) - - acc_test_loss += loss.item() * input.size(0) - - nb_test_samples += input.size(0) - - test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples)) - - log_string(f"test_perplexity {n_epoch} model {model.id} {test_perplexity}") - - model.main_test_accuracy = quiz_machine.produce_results( - n_epoch=n_epoch, - model=model, - result_dir=args.result_dir, - deterministic_synthesis=deterministic_synthesis, - ) - - -def one_epoch(model, quiz_machine, local_device=None): - if local_device is None: - local_device = device - - optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) - - model.to(local_device).train() - - nb_train_samples, acc_train_loss = 0, 0.0 - - for input in quiz_machine.batches(model, split="train"): - input = input.to(local_device) - - if nb_train_samples % args.batch_size == 0: - optimizer.zero_grad() - - output = model(mygpt.BracketedSequence(input)).x - loss = F.cross_entropy(output.transpose(1, 2), input) - acc_train_loss += loss.item() * input.size(0) - - nb_train_samples += input.size(0) - - loss.backward() - - if nb_train_samples % args.batch_size == 0: - optimizer.step() - - train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) - - log_string(f"train_perplexity {n_epoch} model {model.id} {train_perplexity}") - - run_tests(model, quiz_machine, deterministic_synthesis=False) - - -###################################################################### - - -def standard_validity(logproba): - l = logproba.sort(dim=-1).values - return (l[:, 0] < math.log(0.5)) & (l[:, 1] > math.log(0.99)) - # warnings.warn("TEST!!!", RuntimeWarning) - # print(l.exp()) - # return (l[:, 0] < math.log(0.99)) - - -def valid_c_quizzes(recorded, criteria): - result = [q[criteria(lp)] for q, lp in recorded] - return torch.cat(result, dim=0) if len(result) > 0 else torch.tensor([]) - - -###################################################################### - - -def create_c_quizzes( - models, - quiz_machine, - nb_for_train=1000, - nb_for_test=100, -): - quizzes_and_logproba_records = [] - - nb_to_create = nb_for_train + nb_for_test - - # ------------------------------------------------------------ - - file_name = os.path.join(args.result_dir, f"culture_c_quiz_{n_epoch:04d}_logp.dat") - - with open(file_name, "w") as logp_file: - while ( - valid_c_quizzes(quizzes_and_logproba_records, standard_validity).size(0) - < nb_to_create - ): - # Select a model at random to generate the new quizzes - - model_for_generation = models[torch.randint(len(models), (1,))] - - c_quizzes = quiz_machine.generate_quizzes( - nb_to_create, - model_for_generation=model_for_generation, - temperature=args.generation_temperature, - ) - - c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)] - - if c_quizzes.size(0) > 0: - logproba = quiz_machine.logproba_of_solutions(models, c_quizzes) - for l in logproba: - s = " ".join([str(x.item()) for x in l]) - logp_file.write(s + "\n") - quizzes_and_logproba_records.append((c_quizzes, logproba)) - - nb_validated = valid_c_quizzes( - quizzes_and_logproba_records, standard_validity - ).size(0) - - log_string( - f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create}" - ) - - # store the new c_quizzes which have been validated - - new_c_quizzes = valid_c_quizzes(quizzes_and_logproba_records, standard_validity) - - quiz_machine.reverse_random_half_in_place(new_c_quizzes) - - quiz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True) - quiz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False) - - # save a bunch of images to investigate what quizzes with a - # certain nb of correct predictions look like - - q = new_c_quizzes[:72] - - if q.size(0) > 0: - quiz_machine.save_quizzes(args.result_dir, f"culture_c_quiz_{n_epoch:04d}", q) - - -###################################################################### - if torch.cuda.is_available(): device = torch.device("cuda") torch.backends.cuda.matmul.allow_tf32 = True @@ -258,7 +88,7 @@ parser.add_argument("--min_to_validate", type=int, default=None) parser.add_argument("--max_to_validate", type=int, default=None) -parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.9) +parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.975) parser.add_argument("--generation_temperature", type=float, default=2.0) @@ -359,11 +189,6 @@ except FileExistsError: log_file = open(os.path.join(args.result_dir, args.log_filename), "a") -log_string(f"argv {' '.join(sys.argv)}") - -for n in vars(args): - log_string(f"args.{n} {getattr(args, n)}") - if args.seed >= 0: # torch.backends.cudnn.deterministic = True # torch.backends.cudnn.benchmark = False @@ -374,6 +199,26 @@ if args.seed >= 0: ###################################################################### + +def log_string(s): + t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime()) + + if log_file is not None: + log_file.write(t + s + "\n") + log_file.flush() + + print(t + s) + sys.stdout.flush() + + +log_string(f"argv {' '.join(sys.argv)}") + +for n in vars(args): + log_string(f"args.{n} {getattr(args, n)}") + + +###################################################################### + if args.dirty_debug: args.nb_train_samples = 2500 args.nb_test_samples = 100 @@ -408,8 +253,6 @@ elif args.problem == "grids": else: raise ValueError -problem.save_some_examples(args.result_dir) - quiz_machine = quiz_machine.QuizMachine( problem=problem, nb_train_samples=args.nb_train_samples, @@ -431,6 +274,165 @@ log_string(f"vocabulary_size {vocabulary_size}") ###################################################################### + +###################################################################### + + +def run_tests(model, quiz_machine, deterministic_synthesis, local_device=None): + if local_device is None: + local_device = device + + with torch.autograd.no_grad(): + model.eval().to(local_device) + + nb_test_samples, acc_test_loss = 0, 0.0 + nb_samples_accumulated = 0 + + for input in quiz_machine.batches(model, split="test"): + input = input.to(local_device) + + bs = model(mygpt.BracketedSequence(input)) + output = bs.x + + loss = F.cross_entropy(output.transpose(1, 2), input) + + acc_test_loss += loss.item() * input.size(0) + + nb_test_samples += input.size(0) + + test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples)) + + log_string(f"test_perplexity {n_epoch} {test_perplexity}") + + model.main_test_accuracy = quiz_machine.produce_results( + n_epoch=n_epoch, + model=model, + result_dir=args.result_dir, + deterministic_synthesis=deterministic_synthesis, + ) + + +def one_epoch(model, quiz_machine, local_device=None): + if local_device is None: + local_device = device + + optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) + + model.to(local_device).train() + + nb_train_samples, acc_train_loss = 0, 0.0 + + for input in quiz_machine.batches(model, split="train"): + input = input.to(local_device) + + if nb_train_samples % args.batch_size == 0: + optimizer.zero_grad() + + output = model(mygpt.BracketedSequence(input)).x + loss = F.cross_entropy(output.transpose(1, 2), input) + acc_train_loss += loss.item() * input.size(0) + + nb_train_samples += input.size(0) + + loss.backward() + + if nb_train_samples % args.batch_size == 0: + optimizer.step() + + train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) + + log_string(f"train_perplexity {n_epoch} {train_perplexity}") + + run_tests(model, quiz_machine, deterministic_synthesis=False) + + model.TRAINING_LOCK.release() + + +###################################################################### + + +def standard_validity(logproba): + l = logproba.sort(dim=-1).values + return (l[:, 0] < math.log(0.5)) & (l[:, 1] > math.log(0.99)) + # warnings.warn("TEST!!!", RuntimeWarning) + # print(l.exp()) + # return (l[:, 0] < math.log(0.99)) + + +def valid_c_quizzes(recorded, criteria): + result = [q[criteria(lp)] for q, lp in recorded] + return torch.cat(result, dim=0) if len(result) > 0 else torch.tensor([]) + + +###################################################################### + + +def create_c_quizzes( + models, + quiz_machine, + nb_for_train=1000, + nb_for_test=100, +): + quizzes_and_logproba_records = [] + + nb_to_create = nb_for_train + nb_for_test + + # ------------------------------------------------------------ + + file_name = os.path.join(args.result_dir, f"culture_c_quiz_{n_epoch:04d}_logp.dat") + + with open(file_name, "w") as logp_file: + while ( + valid_c_quizzes(quizzes_and_logproba_records, standard_validity).size(0) + < nb_to_create + ): + # Select a model at random to generate the new quizzes + + model_for_generation = models[torch.randint(len(models), (1,))] + + c_quizzes = quiz_machine.generate_quizzes( + nb_to_create, + model_for_generation=model_for_generation, + temperature=args.generation_temperature, + ) + + c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)] + + if c_quizzes.size(0) > 0: + logproba = quiz_machine.logproba_of_solutions(models, c_quizzes) + for l in logproba: + s = " ".join([str(x.item()) for x in l]) + logp_file.write(s + "\n") + quizzes_and_logproba_records.append((c_quizzes, logproba)) + + nb_validated = valid_c_quizzes( + quizzes_and_logproba_records, standard_validity + ).size(0) + + log_string( + f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create}" + ) + + # store the new c_quizzes which have been validated + + new_c_quizzes = valid_c_quizzes(quizzes_and_logproba_records, standard_validity) + + quiz_machine.reverse_random_half_in_place(new_c_quizzes) + + quiz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True) + quiz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False) + + # save a bunch of images to investigate what quizzes with a + # certain nb of correct predictions look like + + q = new_c_quizzes[:72] + + if q.size(0) > 0: + quiz_machine.save_quizzes(args.result_dir, f"culture_c_quiz_{n_epoch:04d}", q) + + +###################################################################### + models = [] for k in range(args.nb_gpts): @@ -448,6 +450,7 @@ for k in range(args.nb_gpts): model.main_test_accuracy = 0.0 model.id = k + model.TRAINING_LOCK = threading.Lock() model.train_w_quizzes = quiz_machine.generate_token_sequences( args.nb_train_samples @@ -529,11 +532,6 @@ if args.dirty_debug: nb_new_c_quizzes_for_train = 100 nb_new_c_quizzes_for_test = 10 - def standard_validity(logproba): - l = logproba.sort(dim=-1).values - return l[:, 0] < math.log(0.99) - - ###################################################################### for n_epoch in range(args.nb_epochs): @@ -543,38 +541,36 @@ for n_epoch in range(args.nb_epochs): log_string(f"current_test_accuracies {cta}") ################################################## - # Select, improve, and eval the worst models + # Select, improve, and eval the worst model ranked_models = sorted(models, key=lambda m: float(m.main_test_accuracy)) weakest_models = ranked_models[: args.nb_gpus] - processes = [] - for gpu_id, model in enumerate(weakest_models): + model.TRAINING_LOCK.acquire() + log_string( f"training model {model.id} main_test_accuracy {model.main_test_accuracy}" ) - process = mp.Process( - target=one_epoch, args=(model, quiz_machine, f"cuda:{gpu_id}") - ) - - processes.append(process) - - for process in processes: - process.start() + threading.Thread( + target=one_epoch, daemon=True, args=(model, quiz_machine, f"cuda:{gpu_id}") + ).start() - for process in processes: - process.join() + for model in weakest_models: + model.TRAINING_LOCK.acquire() + model.TRAINING_LOCK.release() ################################################## - # Renew the train sets + # Replace a fraction of the w_quizzes with fresh ones log_string( f"cache_w_quizzes contains {quiz_machine.problem.nb_cached_quizzes()} quizzes" ) + # Renew entirely the train set + for model in weakest_models: quiz_machine.renew_w_quizzes(model, args.nb_train_samples) diff --git a/problem.py b/problem.py index 7eeb6b4..617b2a8 100755 --- a/problem.py +++ b/problem.py @@ -88,6 +88,3 @@ class Problem: prompts, answers = prompts[:-k], answers[:-k] return prompts, answers - - def save_some_examples(self, result_dir): - pass -- 2.39.5