From edba2f3ee7ade2dfdba5a6556b7c35296d422ccd Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 14 Sep 2024 11:48:04 +0200 Subject: [PATCH] Update. --- main.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/main.py b/main.py index dede204..8010fa4 100755 --- a/main.py +++ b/main.py @@ -57,7 +57,7 @@ parser.add_argument("--inference_batch_size", type=int, default=25) parser.add_argument("--nb_train_samples", type=int, default=25000) -parser.add_argument("--nb_test_samples", type=int, default=10000) +parser.add_argument("--nb_test_samples", type=int, default=1000) parser.add_argument("--nb_train_alien_samples", type=int, default=0) @@ -719,7 +719,9 @@ def logits_hat_x_0_from_random_iteration(model, x_0, mask_generate, prompt_noise x_t = (1 - mask_generate) * noisy_x_t + mask_generate * x_t x_t_with_mask = NTC_channel_cat(x_t, mask_generate) - logits_hat_x_0 = model(x_t_with_mask) + + with torch.cuda.amp.autocast(): + logits_hat_x_0 = model(x_t_with_mask) return logits_hat_x_0 @@ -743,7 +745,8 @@ def ae_generate(model, x_0, mask_generate, nb_iterations_max=50, mask_hints=None for it in range(nb_iterations_max): x_t_with_mask = NTC_channel_cat(x_t, mask_generate) - logits = model(x_t_with_mask) + with torch.cuda.amp.autocast(): + logits = model(x_t_with_mask) logits[:, :, quiz_machine.problem.nb_colors :] = float("-inf") dist = torch.distributions.categorical.Categorical(logits=logits) @@ -891,6 +894,8 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi nb_train_samples, acc_train_loss = 0, 0.0 + scaler = torch.cuda.amp.GradScaler() + for x_0, mask_generate in ae_batches( quiz_machine, args.nb_train_samples, @@ -905,18 +910,21 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi if nb_train_samples % args.batch_size == 0: model.optimizer.zero_grad() - logits = logits_hat_x_0_from_random_iteration( - model, x_0, mask_generate, prompt_noise=args.prompt_noise - ) + with torch.cuda.amp.autocast(): + logits = logits_hat_x_0_from_random_iteration( + model, x_0, mask_generate, prompt_noise=args.prompt_noise + ) loss = NTC_masked_cross_entropy(logits, x_0, mask_generate) acc_train_loss += loss.item() * x_0.size(0) nb_train_samples += x_0.size(0) - loss.backward() + scaler.scale(loss).backward() if nb_train_samples % args.batch_size == 0: - model.optimizer.step() + scaler.step(model.optimizer) + + scaler.update() log_string( f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}" -- 2.39.5