From 56b7a714b294965d99d52859564970adef328952 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 27 Aug 2024 14:21:09 +0200 Subject: [PATCH] Update. --- main.py | 332 ++++++++++++++++++++++++++++++-------------------------- 1 file changed, 180 insertions(+), 152 deletions(-) diff --git a/main.py b/main.py index eb0f776..2e8ec43 100755 --- a/main.py +++ b/main.py @@ -93,7 +93,7 @@ parser.add_argument("--gpus", type=str, default="all") # ---------------------------------- -parser.add_argument("--nb_gpts", type=int, default=5) +parser.add_argument("--nb_models", type=int, default=5) parser.add_argument("--min_succeed_to_validate", type=int, default=2) @@ -464,6 +464,16 @@ c_quizzes_procedure = [ # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_modifier_cold), ] +# quad_order, quad_generate, quad_noise, quad_loss + +data_structures = [ + (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)), + (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0)), + (("A", "f_A", "B", "f_B"), (0, 1, 0, 0), (1, 0, 0, 0), (0, 1, 0, 0)), + (("A", "f_A", "B", "f_B"), (1, 0, 0, 0), (0, 1, 0, 0), (1, 0, 0, 0)), + (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)), +] + ###################################################################### @@ -783,6 +793,18 @@ class MyAttentionAE(nn.Module): trunk_blocks = [] for b in range(nb_blocks): + # if b == nb_blocks//2: + # trunk_blocks += [ + # QKVAttention( + # dim_in=dim_model, + # dim_qk=dim_keys, + # dim_v=dim_model // nb_heads, + # nb_heads=nb_heads, + # attention_dropout=dropout, + # ), + # VaswaniPositionalEncoding(len_max=1e5) + # ] + trunk_blocks += [ WithResidual( CacheWrapper( @@ -864,20 +886,6 @@ def ae_batches( mask_loss.to(local_device), ) - # quiz_machine.problem.save_quizzes_as_image( - # args.result_dir, - # filename="a.png", - # quizzes=a, - # ) - - # quiz_machine.problem.save_quizzes_as_image( - # args.result_dir, - # filename="b.png", - # quizzes=b, - # ) - - # time.sleep(1000) - def NTC_masked_cross_entropy(output, targets, mask): loss_per_token = F.cross_entropy(output.transpose(1, 2), targets, reduction="none") @@ -892,6 +900,12 @@ def deterministic(mask_generate): return (mask_generate.sum(dim=1) < mask_generate.size(1) // 2).long() +# This function returns a tensor of same shape as low, full of uniform +# random values in [0,1], such that the values corresponding to the +# True in low are all lesser than the values corresponding to the +# False. + + def prioritized_rand(low): x = torch.rand(low.size(), device=low.device).sort(dim=1, descending=True).values k = torch.rand(low.size(), device=low.device) + low.long() @@ -901,17 +915,15 @@ def prioritized_rand(low): return y -def ae_generate( - model, input, mask_generate, n_epoch, noise_proba, nb_iterations_max=50 -): +def ae_generate(model, input, mask_generate, noise_proba, nb_iterations_max=50): noise = torch.randint( quiz_machine.problem.nb_colors, input.size(), device=input.device ) - input = (1 - mask_generate) * input + mask_generate * noise - proba_erased = noise_proba + input = (1 - mask_generate) * input + mask_generate * noise d = deterministic(mask_generate)[:, None] + changed = True for it in range(nb_iterations_max): @@ -922,7 +934,8 @@ def ae_generate( r = prioritized_rand(final != input) - mask_erased = mask_generate * (r <= proba_erased).long() + mask_erased = mask_generate * (r <= noise_proba).long() + mask_to_change = d * mask_generate + (1 - d) * mask_erased update = (1 - mask_to_change) * input + mask_to_change * final @@ -956,56 +969,22 @@ def degrade_input(input, mask_generate, nb_iterations, noise_proba): return result -def test_ae(local_device=main_device): - model = MyAttentionAE( - vocabulary_size=vocabulary_size, - dim_model=args.dim_model, - dim_keys=args.dim_keys, - dim_hidden=args.dim_hidden, - nb_heads=args.nb_heads, - nb_blocks=args.nb_blocks, - dropout=args.dropout, - ).to(main_device) - - # quad_order, quad_generate, quad_noise, quad_loss - - data_structures = [ - (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)), - (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0)), - (("A", "f_A", "B", "f_B"), (0, 1, 0, 0), (1, 0, 0, 0), (0, 1, 0, 0)), - (("A", "f_A", "B", "f_B"), (1, 0, 0, 0), (0, 1, 0, 0), (1, 0, 0, 0)), - (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)), - ] - - model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) - - model.to(local_device).train() - optimizer_to(model.optimizer, local_device) - - nb_iterations = 25 - probs_iterations = torch.arange(nb_iterations, device=main_device) - probs_iterations = 0.1 ** (probs_iterations / nb_iterations) - probs_iterations = probs_iterations[None, :] / probs_iterations.sum() +###################################################################### - for n_epoch in range(args.nb_epochs): - # ---------------------- - # Train - model.train() - nb_train_samples, acc_train_loss = 0, 0.0 +def run_ae_test(model, quiz_machine, n_epoch, local_device=main_device): + with torch.autograd.no_grad(): + model.eval().to(local_device) - noise_proba = 0.05 + nb_test_samples, acc_test_loss = 0, 0.0 for input, mask_generate, mask_loss in ae_batches( quiz_machine, - args.nb_train_samples, + args.nb_test_samples, data_structures, local_device, - "training", + "test", ): - if nb_train_samples % args.batch_size == 0: - model.optimizer.zero_grad() - d = deterministic(mask_generate) p = probs_iterations.expand(input.size(0), -1) dist = torch.distributions.categorical.Categorical(probs=p) @@ -1013,119 +992,168 @@ def test_ae(local_device=main_device): N1 = N0 + 1 N0 = (1 - d) * N0 N1 = (1 - d) * N1 + d * nb_iterations - targets, input = degrade_input( input, mask_generate, (0 * N1, N1), noise_proba=noise_proba ) - - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # for n in ["input", "targets"]: - # filename = f"{n}.png" - # quiz_machine.problem.save_quizzes_as_image( - # args.result_dir, - # filename, - # quizzes=locals()[n], - # ) - # log_string(f"wrote {filename}") - # time.sleep(1000) - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - input_with_mask = NTC_channel_cat(input, mask_generate) logits = model(input_with_mask) loss = NTC_masked_cross_entropy(logits, targets, mask_loss) - acc_train_loss += loss.item() * input.size(0) - nb_train_samples += input.size(0) + acc_test_loss += loss.item() * input.size(0) + nb_test_samples += input.size(0) - loss.backward() + log_string( + f"test_loss {n_epoch} model {model.id} {acc_test_loss/nb_test_samples}" + ) - if nb_train_samples % args.batch_size == 0: - model.optimizer.step() + # ------------------------------------------- + # Test generation - train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) + nb_correct, nb_total, record = 0, 0, [] - log_string(f"train_loss {n_epoch} model AE {acc_train_loss/nb_train_samples}") + for input, mask_generate, mask_loss in ae_batches( + quiz_machine, + args.nb_test_samples, + data_structures, + local_device, + "test", + ): + targets = input.clone() + result = ae_generate( + model, (1 - mask_generate) * input, mask_generate, noise_proba + ) + correct = (result == targets).min(dim=1).values.long() + predicted_parts = mask_generate.reshape(mask_generate.size(0), 4, -1)[ + :, :, 1 + ] + solution_is_deterministic = predicted_parts.sum(dim=-1) == 1 + correct = (2 * correct - 1) * (solution_is_deterministic).long() + nb_correct += (correct == 1).long().sum() + nb_total += (correct != 0).long().sum() + correct_parts = predicted_parts * correct[:, None] + record.append((result, predicted_parts, correct_parts)) - # ---------------------- - # Test + log_string( + f"test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)" + ) - with torch.autograd.no_grad(): - model.eval() + model.test_accuracy = nb_correct / nb_total - nb_test_samples, acc_test_loss = 0, 0.0 + filename = f"prediction_ae_{n_epoch:04d}.png" - for input, mask_generate, mask_loss in ae_batches( - quiz_machine, - args.nb_test_samples, - data_structures, - local_device, - "test", - ): - d = deterministic(mask_generate) - p = probs_iterations.expand(input.size(0), -1) - dist = torch.distributions.categorical.Categorical(probs=p) - N0 = dist.sample() - N1 = N0 + 1 - N0 = (1 - d) * N0 - N1 = (1 - d) * N1 + d * nb_iterations - targets, input = degrade_input( - input, mask_generate, (0 * N1, N1), noise_proba=noise_proba - ) - input_with_mask = NTC_channel_cat(input, mask_generate) - logits = model(input_with_mask) - loss = NTC_masked_cross_entropy(logits, targets, mask_loss) - acc_test_loss += loss.item() * input.size(0) - nb_test_samples += input.size(0) + result, predicted_parts, correct_parts = ( + torch.cat([x[i] for x in record]) for i in [0, 1, 2] + ) + + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, + filename, + quizzes=result, + predicted_parts=predicted_parts, + correct_parts=correct_parts, + ) - log_string(f"test_loss {n_epoch} model AE {acc_test_loss/nb_test_samples}") + log_string(f"wrote {filename}") - # ------------------------------------------- - # Test generation - for ns, s in enumerate(data_structures): - quad_order, quad_generate, _, _ = s +###################################################################### - input, mask_generate, _ = next( - ae_batches(quiz_machine, 128, [s], local_device, batch_size=128) - ) - targets = input.clone() - input = ae_generate( - model, - input, - mask_generate, - n_epoch, - noise_proba=noise_proba, - ) +def one_ae_epoch(model, quiz_machine, n_epoch, local_device=main_device): + model.train().to(local_device) - correct = (input == targets).min(dim=1).values.long() - predicted_parts = torch.tensor(quad_generate, device=input.device) - predicted_parts = predicted_parts[None, :].expand(input.size(0), -1) - solution_is_deterministic = predicted_parts.sum(dim=-1) == 1 - correct = (2 * correct - 1) * (solution_is_deterministic).long() - nb_correct = (correct == 1).long().sum() - nb_total = (correct != 0).long().sum() - correct_parts = predicted_parts * correct[:, None] - - log_string( - f"test_accuracy {n_epoch} model AE setup {ns} {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)" - ) + nb_train_samples, acc_train_loss = 0, 0.0 - filename = f"prediction_ae_{n_epoch:04d}_{ns}.png" + for input, mask_generate, mask_loss in ae_batches( + quiz_machine, + args.nb_train_samples, + data_structures, + local_device, + "training", + ): + if nb_train_samples % args.batch_size == 0: + model.optimizer.zero_grad() - quiz_machine.problem.save_quizzes_as_image( - args.result_dir, - filename, - quizzes=input, - predicted_parts=predicted_parts, - correct_parts=correct_parts, - ) + d = deterministic(mask_generate) + p = probs_iterations.expand(input.size(0), -1) + dist = torch.distributions.categorical.Categorical(probs=p) + N0 = dist.sample() + N1 = N0 + 1 + N0 = (1 - d) * N0 + N1 = (1 - d) * N1 + d * nb_iterations - log_string(f"wrote {filename}") + targets, input = degrade_input( + input, mask_generate, (0 * N1, N1), noise_proba=noise_proba + ) + input_with_mask = NTC_channel_cat(input, mask_generate) + logits = model(input_with_mask) + loss = NTC_masked_cross_entropy(logits, targets, mask_loss) + acc_train_loss += loss.item() * input.size(0) + nb_train_samples += input.size(0) -if args.test == "ae": - test_ae(local_device=main_device) - exit(0) + loss.backward() + + if nb_train_samples % args.batch_size == 0: + model.optimizer.step() + + log_string( + f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}" + ) + + run_ae_test(model, quiz_machine, n_epoch, local_device=local_device) + + +###################################################################### + +noise_proba = 0.05 + +nb_iterations = 25 +probs_iterations = 0.1 ** torch.linspace(0, 1, nb_iterations, device=main_device) +probs_iterations = probs_iterations[None, :] / probs_iterations.sum() + +models = [] + +for i in range(args.nb_models): + model = MyAttentionAE( + vocabulary_size=vocabulary_size, + dim_model=args.dim_model, + dim_keys=args.dim_keys, + dim_hidden=args.dim_hidden, + nb_heads=args.nb_heads, + nb_blocks=args.nb_blocks, + dropout=args.dropout, + ).to(main_device) + + model.id = i + model.test_accuracy = 0.0 + model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) + + model.to(main_device).train() + optimizer_to(model.optimizer, main_device) + + models.append(model) + +for n_epoch in range(args.nb_epochs): + ranked_models = sorted(models, key=lambda m: float(m.test_accuracy)) + weakest_models = ranked_models[: len(gpus)] + + threads = [] + + start_time = time.perf_counter() + + for gpu, model in zip(gpus, weakest_models): + log_string(f"training model {model.id}") + + t = threading.Thread( + target=one_ae_epoch, daemon=True, args=(model, quiz_machine, n_epoch, gpu) + ) + + threads.append(t) + + t.start() + + for t in threads: + t.join() ###################################################################### @@ -1136,7 +1164,7 @@ def create_models(): def compute_causal_attzero(t_q, t_k): return t_q < t_k - for k in range(args.nb_gpts): + for k in range(args.nb_models): log_string(f"creating model {k}") model = mygpt.MyGPT( @@ -1244,7 +1272,7 @@ log_string( if args.dirty_debug: args.accuracy_to_make_c_quizzes = 0.0 - args.nb_gpts = 2 + args.nb_models = 2 args.nb_new_c_quizzes_for_train = 100 args.nb_new_c_quizzes_for_test = 10 -- 2.39.5