From ef093a189280ca6bb6bbd55a5e3f7ef0e4ed0e8f Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 16 Sep 2024 23:25:33 +0200 Subject: [PATCH] Update. --- main.py | 46 +++++++++++++++++++++++++++++----------------- 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/main.py b/main.py index d21c54b..899a099 100755 --- a/main.py +++ b/main.py @@ -640,22 +640,22 @@ def one_epoch_(model, n_epoch, c_quizzes, local_device=main_device): def batch_prediction(input, proba_hints=0.0): nb = input.size(0) - mask_generate = input.new_zeros(input.size()) - u = F.one_hot(torch.randint(4, (nb,), device=mask_generate.device), num_classes=4) - mask_generate.view(nb, 4, -1)[:, :, 1:] = u[:, :, None] + mask = input.new_zeros(input.size()) + u = F.one_hot(torch.randint(4, (nb,), device=mask.device), num_classes=4) + mask.view(nb, 4, -1)[:, :, 1:] = u[:, :, None] if proba_hints > 0: - h = torch.rand(input.size(), device=input.device) * mask_generate + h = torch.rand(input.size(), device=input.device) * mask mask_hints = h.sort(dim=1, descending=True).values < args.nb_hints v = torch.rand(nb, device=input.device)[:, None] mask_hints = mask_hints * (v < proba_hints).long() - mask_generate = (1 - mask_hints) * mask_generate + mask = (1 - mask_hints) * mask # noise = quiz_machine.problem.pure_noise(nb, input.device) targets = input - input = (1 - mask_generate) * targets # + mask_generate * noise + input = (1 - mask) * targets # + mask * noise - return input, targets, mask_generate + return input, targets, mask def predict(model, input, targets, mask, local_device=main_device): @@ -704,10 +704,10 @@ def batch_generation(input): targets = input input = (1 - mask_erased) * input + mask_erased * noise - mask_generate = input.new_full(input.size(), 1) - mask_generate.reshape(mask_generate.size(0), 4, -1)[:, :, 0] = 0 + mask = input.new_full(input.size(), 1) + mask.reshape(mask.size(0), 4, -1)[:, :, 0] = 0 - return input, targets, mask_generate + return input, targets, mask def prioritized_rand(low): @@ -721,20 +721,19 @@ def prioritized_rand(low): def generate(model, nb, local_device=main_device): input = quiz_machine.problem.pure_noise(nb, local_device) - mask_generate = input.new_full(input.size(), 1) - mask_generate.reshape(mask_generate.size(0), 4, -1)[:, :, 0] = 0 + mask = input.new_full(input.size(), 1) + mask.reshape(mask.size(0), 4, -1)[:, :, 0] = 0 changed = True - for it in range(self.diffusion_nb_iterations): + for it in range(args.diffusion_nb_iterations): with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = model(NTC_channel_cat(input, mask_generate)) + logits = model(NTC_channel_cat(input, mask)) dist = torch.distributions.categorical.Categorical(logits=logits) output = dist.sample() - r = self.prioritized_rand(input != output) - mask_changes = (r <= self.proba_corruption).long() + r = prioritized_rand(input != output) + mask_changes = (r <= args.diffusion_proba_corruption).long() * mask update = (1 - mask_changes) * input + mask_changes * output - if update.equal(input): break else: @@ -803,9 +802,13 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True): def one_train_test_epoch(model, n_epoch, c_quizzes, local_device=main_device): + # train + one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True) one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=False) + # predict + quizzes = quiz_machine.quiz_set(150, c_quizzes, args.c_quiz_multiplier) input, targets, mask = batch_prediction(quizzes.to(local_device)) result = predict(model, input, targets, mask).to("cpu") @@ -825,6 +828,15 @@ def one_train_test_epoch(model, n_epoch, c_quizzes, local_device=main_device): model.test_accuracy = correct.sum() / quizzes.size(0) + # generate + + result = generate(model, 25).to("cpu") + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, + f"culture_generation_{n_epoch}_{model.id}.png", + quizzes=result[:128], + ) + ###################################################################### -- 2.39.5