From 4895ed0b5f877d0d8b7b740d1822e10c04393474 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Tue, 17 Sep 2024 09:45:36 +0200
Subject: [PATCH] Update.

---
 main.py | 305 ++++++------------------------------------------------------
 1 file changed, 33 insertions(+), 272 deletions(-)

diff --git a/main.py b/main.py
index 3339838..8054509 100755
--- a/main.py
+++ b/main.py
@@ -390,91 +390,6 @@ data_structures = [
 ######################################################################
 
 
-def model_proba_solutions(model, input, log_probas=False, reduce=True):
-    record = []
-
-    for x_0 in input.split(args.batch_size):
-        loss = 0
-
-        for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
-            mask_generate = quiz_machine.make_quiz_mask(
-                quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
-            )
-            logits = logits_hat_x_0_from_random_iteration(
-                model=model,
-                x_0=x_0,
-                mask_generate=mask_generate,
-                prompt_noise=args.prompt_noise,
-            )
-            loss_per_token = F.cross_entropy(
-                logits.transpose(1, 2), x_0, reduction="none"
-            )
-            if reduce:
-                loss += (loss_per_token * mask_generate).sum(dim=1)
-            else:
-                loss += loss_per_token * mask_generate
-
-        record.append(loss)
-
-    loss = torch.cat(record, dim=0)
-
-    if log_probas:
-        return -loss
-    else:
-        return (-loss).exp()
-
-
-######################################################################
-
-
-def batches(
-    quiz_machine,
-    nb,
-    data_structures,
-    local_device,
-    c_quizzes=None,
-    alien_quiz_machine=None,
-    desc=None,
-    batch_size=args.batch_size,
-):
-    c_quiz_bags = [] if c_quizzes is None else [c_quizzes.to("cpu")]
-
-    full_input, full_mask_generate, _ = quiz_machine.data_input(
-        nb,
-        c_quiz_bags,
-        data_structures=data_structures,
-        c_quiz_multiplier=args.c_quiz_multiplier,
-    )
-
-    src = zip(
-        full_input.split(batch_size),
-        full_mask_generate.split(batch_size),
-    )
-
-    if desc is not None:
-        src = tqdm.tqdm(
-            src,
-            dynamic_ncols=True,
-            desc=desc,
-            total=full_input.size(0) // batch_size,
-        )
-
-    for input, mask_generate in src:
-        yield (
-            input.to(local_device),
-            mask_generate.to(local_device),
-        )
-
-
-def NTC_channel_cat(*x):
-    return torch.cat([a.expand_as(x[0])[:, :, None] for a in x], dim=2)
-
-
-def NTC_masked_cross_entropy(output, targets, mask):
-    loss_per_token = F.cross_entropy(output.transpose(1, 2), targets, reduction="none")
-    return (loss_per_token * mask).mean()
-
-
 def masked_cross_entropy(output, targets, masks):
     loss_per_token = F.cross_entropy(output.transpose(1, 2), targets, reduction="none")
     return (loss_per_token * masks).sum() / masks.expand_as(loss_per_token).sum()
@@ -482,165 +397,7 @@ def masked_cross_entropy(output, targets, masks):
 
 ######################################################################
 
-
-def run_test(
-    model, quiz_machine, n_epoch, c_quizzes=None, local_device=main_device, prefix=None
-):
-    if prefix is None:
-        prefix = ""
-    else:
-        prefix = prefix + "_"
-
-    with torch.autograd.no_grad():
-        model.eval().to(local_device)
-
-        # Compute the loss
-
-        nb_test_samples, acc_test_loss = 0, 0.0
-
-        for x_0, mask_generate in batches(
-            quiz_machine,
-            args.nb_test_samples,
-            data_structures,
-            local_device,
-            c_quizzes=c_quizzes,
-            desc="test",
-        ):
-            logits = diffuser.logits_hat_x_0_from_random_iteration(
-                model=model,
-                x_0=x_0,
-                mask_generate=mask_generate,
-            )
-            loss = masked_cross_entropy(logits, x_0, mask_generate)
-            acc_test_loss += loss.item() * x_0.size(0)
-            nb_test_samples += x_0.size(0)
-
-        log_string(
-            f"{prefix}test_loss {n_epoch} model {model.id} {acc_test_loss/nb_test_samples}"
-        )
-
-        # Compute the accuracy and save some images
-
-        nb_correct, nb_total, record_d, record_nd = 0, 0, [], []
-
-        for x_0, mask_generate in batches(
-            quiz_machine,
-            args.nb_test_samples,
-            data_structures,
-            local_device,
-            c_quizzes=c_quizzes,
-            desc="test",
-        ):
-            result = diffuser.generate(model, (1 - mask_generate) * x_0, mask_generate)
-            correct = (result == x_0).min(dim=1).values.long()
-            predicted_parts = mask_generate.reshape(mask_generate.size(0), 4, -1)[
-                :, :, 1
-            ]
-            d = predicted_parts.sum(dim=-1) == 1
-            correct = (2 * correct - 1) * d.long()
-            nb_correct += (correct == 1).long().sum()
-            nb_total += (correct != 0).long().sum()
-            correct_parts = predicted_parts * correct[:, None]
-            record_d.append((result[d], predicted_parts[d], correct_parts[d]))
-            nd = d == False
-            record_nd.append((result[nd], predicted_parts[nd], correct_parts[nd]))
-
-        log_string(
-            f"{prefix}test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)"
-        )
-
-        # Save some images
-
-        for f, record in [("prediction", record_d), ("generation", record_nd)]:
-            result, predicted_parts, correct_parts = bag_to_tensors(record)
-
-            filename = f"{prefix}culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
-
-            quiz_machine.problem.save_quizzes_as_image(
-                args.result_dir,
-                filename,
-                quizzes=result[:128],
-                predicted_parts=predicted_parts[:128],
-                correct_parts=correct_parts[:128],
-            )
-
-            log_string(f"wrote {filename}")
-
-        return nb_correct / nb_total
-
-
-######################################################################
-
-
-def one_epoch_(model, n_epoch, c_quizzes, local_device=main_device):
-    model.train().to(local_device)
-    optimizer_to(model.optimizer, local_device)
-
-    nb_train_samples, acc_train_loss = 0, 0.0
-
-    # scaler = torch.amp.GradScaler("cuda")
-
-    for x_0, mask_generate in batches(
-        quiz_machine,
-        args.nb_train_samples,
-        data_structures,
-        local_device,
-        c_quizzes=c_quizzes,
-        desc="training",
-    ):
-        x_0 = x_0.to(local_device)
-        mask_generate = mask_generate.to(local_device)
-
-        if nb_train_samples % args.batch_size == 0:
-            model.optimizer.zero_grad()
-
-        nb_hints = torch.randint(2, (x_0.size(0),), device=x_0.device) * args.nb_hints
-
-        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
-            logits = diffuser.logits_hat_x_0_from_random_iteration(
-                model=model,
-                x_0=x_0,
-                mask_generate=mask_generate,
-                prompt_noise=args.prompt_noise,
-                nb_hints=nb_hints,
-            )
-
-        loss = masked_cross_entropy(logits, x_0, mask_generate)
-        acc_train_loss += loss.item() * x_0.size(0)
-        nb_train_samples += x_0.size(0)
-
-        loss.backward()
-
-        if nb_train_samples % args.batch_size == 0:
-            model.optimizer.step()
-
-        # scaler.scale(loss).backward()
-
-        # if nb_train_samples % args.batch_size == 0:
-        # scaler.step(model.optimizer)
-
-        # scaler.update()
-
-    log_string(
-        f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}"
-    )
-
-    model.test_accuracy = run_test(
-        model, quiz_machine, n_epoch, c_quizzes=None, local_device=local_device
-    )
-
-    if args.nb_test_alien_samples > 0:
-        run_test(
-            model,
-            alien_quiz_machine,
-            n_epoch,
-            c_quizzes=None,
-            local_device=local_device,
-            prefix="alien",
-        )
-
-
-######################################################################
+# IMT for input / masks / target
 
 
 def IMT_batch_prediction(input, proba_hints=0.0):
@@ -687,8 +444,6 @@ def predict(model, imt_set, local_device=main_device):
 
 ######################################################################
 
-# IMT for input / masks / target
-
 
 def IMT_batch_generation(input):
     nb = input.size(0)
@@ -723,25 +478,34 @@ def prioritized_rand(low):
 
 
 def generate(model, nb, local_device=main_device):
-    input = quiz_machine.problem.pure_noise(nb, local_device)
-    masks = input.new_full(input.size(), 1)
-    masks.reshape(masks.size(0), 4, -1)[:, :, 0] = 0
-
-    changed = True
-    for it in range(args.diffusion_nb_iterations):
-        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
-            logits = model(input)
-        dist = torch.distributions.categorical.Categorical(logits=logits)
-        output = dist.sample()
-
-        r = prioritized_rand(input != output)
-        mask_changes = (r <= args.diffusion_proba_corruption).long() * masks
-        update = (1 - mask_changes) * input + mask_changes * output
-        if update.equal(input):
-            break
-        else:
-            changed = changed & (update != input).max(dim=1).values
-            input[changed] = update[changed]
+    all_input = quiz_machine.problem.pure_noise(nb, local_device)
+    all_masks = all_input.new_full(all_input.size(), 1)
+    all_masks.reshape(all_masks.size(0), 4, -1)[:, :, 0] = 0
+
+    for input, masks in tqdm.tqdm(
+        zip(
+            all_input.split(args.physical_batch_size),
+            all_masks.split(args.physical_batch_size),
+        ),
+        dynamic_ncols=True,
+        desc="predict",
+        total=all_input.size(0) // args.physical_batch_size,
+    ):
+        changed = True
+        for it in range(args.diffusion_nb_iterations):
+            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+                logits = model(input)
+            dist = torch.distributions.categorical.Categorical(logits=logits)
+            output = dist.sample()
+
+            r = prioritized_rand(input != output)
+            mask_changes = (r <= args.diffusion_proba_corruption).long() * masks
+            update = (1 - mask_changes) * input + mask_changes * output
+            if update.equal(input):
+                break
+            else:
+                changed = changed & (update != input).max(dim=1).values
+                input[changed] = update[changed]
 
     return input
 
@@ -749,10 +513,6 @@
 ######################################################################
 
 
-def batch_interleave(a, b, perm):
-    return torch.cat([a, b])[perm].reshape(-1, args.physical_batch_size, a.size(1))
-
-
 def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
     if train:
         label = "train"
@@ -770,9 +530,10 @@
         args.c_quiz_multiplier,
     )
 
-    input_p, input_g = quizzes.to(local_device).chunk(2)
+    q1, q2 = quizzes.to(local_device).chunk(2)
+
     imt_set = torch.cat(
-        [IMT_batch_prediction(input_p, proba_hints=0.5), IMT_batch_generation(input_g)]
+        [IMT_batch_prediction(q1, proba_hints=0.5), IMT_batch_generation(q2)]
     )
 
     imt_set = imt_set[torch.randperm(imt_set.size(0), device=imt_set.device)]
@@ -831,7 +592,7 @@ def one_train_test_epoch(model, n_epoch, c_quizzes, local_device=main_device):
 
     # generate
 
-    result = generate(model, 25, local_device=local_device).to("cpu")
+    result = generate(model, 150, local_device=local_device).to("cpu")
     quiz_machine.problem.save_quizzes_as_image(
         args.result_dir,
         f"culture_generation_{n_epoch}_{model.id}.png",
-- 
2.39.5