From: François Fleuret
Date: Sun, 15 Sep 2024 10:18:18 +0000 (+0200)
Subject: Update.
X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=b487f34019b2415b5290d1409014e132f86f63d7;p=culture.git

Update.
---

diff --git a/diffusion.py b/diffusion.py
index 98d8d0a..8c6e08d 100755
--- a/diffusion.py
+++ b/diffusion.py
@@ -92,8 +92,8 @@ class Diffuser:
 
         x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
 
-        # with torch.amp.autocast("cuda"):
-        logits_hat_x_0 = model(x_t_with_mask)
+        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+            logits_hat_x_0 = model(x_t_with_mask)
 
         return logits_hat_x_0
 
@@ -117,8 +117,8 @@ class Diffuser:
         for it in range(self.nb_iterations):
             x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
 
-            # with torch.amp.autocast("cuda"):
-            logits = model(x_t_with_mask)
+            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+                logits = model(x_t_with_mask)
 
             # logits[:, :, quiz_machine.problem.nb_colors :] = float("-inf")
             dist = torch.distributions.categorical.Categorical(logits=logits)
diff --git a/main.py b/main.py
index 534bab9..19a3fee 100755
--- a/main.py
+++ b/main.py
@@ -27,6 +27,8 @@ import threading, subprocess
 
 # torch.set_float32_matmul_precision("high")
 
+# torch.set_default_dtype(torch.bfloat16)
+
 import diffusion
 
 ######################################################################
@@ -565,7 +567,7 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
 
     nb_train_samples, acc_train_loss = 0, 0.0
 
-    # scaler = torch.amp.GradScaler("cuda")
+    scaler = torch.amp.GradScaler("cuda")
 
     for x_0, mask_generate in ae_batches(
         quiz_machine,
@@ -581,29 +583,29 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
         if nb_train_samples % args.batch_size == 0:
             model.optimizer.zero_grad()
 
-        # with torch.amp.autocast("cuda"):
-        logits = diffuser.logits_hat_x_0_from_random_iteration(
-            model=model,
-            x_0=x_0,
-            mask_generate=mask_generate,
-            prompt_noise=args.prompt_noise,
-        )
+        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+            logits = diffuser.logits_hat_x_0_from_random_iteration(
+                model=model,
+                x_0=x_0,
+                mask_generate=mask_generate,
+                prompt_noise=args.prompt_noise,
+            )
 
         loss = NTC_masked_cross_entropy(logits, x_0, mask_generate)
         acc_train_loss += loss.item() * x_0.size(0)
         nb_train_samples += x_0.size(0)
 
-        loss.backward()
+        # loss.backward()
 
-        if nb_train_samples % args.batch_size == 0:
-            model.optimizer.step()
+        # if nb_train_samples % args.batch_size == 0:
+        #     model.optimizer.step()
 
-        # scaler.scale(loss).backward()
+        scaler.scale(loss).backward()
 
-        # if nb_train_samples % args.batch_size == 0:
-        #     scaler.step(model.optimizer)
+        if nb_train_samples % args.batch_size == 0:
+            scaler.step(model.optimizer)
 
-        # scaler.update()
+        scaler.update()
 
     log_string(
         f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}"
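
Note (not part of the patch): a minimal, self-contained sketch of the torch.amp pattern this commit enables, i.e. running the forward pass under bfloat16 autocast and driving the backward pass through a GradScaler. The names train_one_epoch, model, optimizer, criterion, loader, x and y below are placeholders for illustration, not identifiers from this repository; GradScaler mainly matters for float16 and is largely a pass-through with bfloat16.

    import torch

    def train_one_epoch(model, optimizer, criterion, loader, device="cuda"):
        # One scaler for the whole epoch; its scale factor persists across steps.
        scaler = torch.amp.GradScaler("cuda")

        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()

            # Forward pass under bfloat16 autocast, as in the updated diffusion.py / main.py.
            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                loss = criterion(model(x), y)

            # Scaled backward / optimizer step / scale update, mirroring the
            # scaler path uncommented in one_ae_epoch.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()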