From 304845428c84ed5697133f539032378b26dc240a Mon Sep 17 00:00:00 2001
From: François Fleuret
Date: Thu, 22 Aug 2024 21:04:10 +0200
Subject: [PATCH] Update.

---
 main.py  | 57 +++++++++++++++++++++++++++++++++++++++++++-------------
 mygpt.py |  2 +-
 2 files changed, 45 insertions(+), 14 deletions(-)

diff --git a/main.py b/main.py
index 36b369e..b6c62cf 100755
--- a/main.py
+++ b/main.py
@@ -847,25 +847,24 @@ def test_ae(local_device=main_device):
     model.train()
     nb_train_samples, acc_train_loss = 0, 0.0
 
-    data_structures = [
-        (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)),
-    ]
-
     full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input(
-        args.nb_train_samples, data_structures=data_structures
+        args.nb_train_samples
     )
 
     src = zip(
-        full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
+        full_input.split(args.batch_size),
+        full_mask_generate.split(args.batch_size),
+        full_mask_loss.split(args.batch_size),
     )
 
-    for input, mask_loss in tqdm.tqdm(
+    for input, mask_generate, mask_loss in tqdm.tqdm(
         src,
         dynamic_ncols=True,
         desc="training",
         total=full_input.size(0) // args.batch_size,
     ):
         input = input.to(local_device)
+        mask_generate = mask_generate.to(local_device)
         mask_loss = mask_loss.to(local_device)
 
         if nb_train_samples % args.batch_size == 0:
@@ -896,21 +895,25 @@ def test_ae(local_device=main_device):
     nb_test_samples, acc_test_loss = 0, 0.0
 
     full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input(
-        args.nb_test_samples, data_structures=data_structures
+        args.nb_test_samples
     )
 
     src = zip(
-        full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
+        full_input.split(args.batch_size),
+        full_mask_generate.split(args.batch_size),
+        full_mask_loss.split(args.batch_size),
     )
 
-    for input, mask_loss in tqdm.tqdm(
+    for input, mask_generate, mask_loss in tqdm.tqdm(
         src,
         dynamic_ncols=True,
         desc="testing",
         total=full_input.size(0) // args.batch_size,
     ):
         input = input.to(local_device)
+        mask_generate = mask_generate.to(local_device)
         mask_loss = mask_loss.to(local_device)
+
         targets = input
         input = (mask_generate == 0).long() * input
         output = model(mygpt.BracketedSequence(input)).x
@@ -920,10 +923,9 @@ def test_ae(local_device=main_device):
 
     log_string(f"test_loss {n_epoch} model AE {acc_test_loss/nb_test_samples}")
 
-    input, mask_generate, mask_loss = quiz_machine.data_input(
-        128, data_structures=data_structures
-    )
+    input, mask_generate, mask_loss = quiz_machine.data_input(128)
     input = input.to(local_device)
+    mask_generate = mask_generate.to(local_device)
     mask_loss = mask_loss.to(local_device)
     targets = input
     input = (mask_generate == 0).long() * input
@@ -935,12 +937,41 @@ def test_ae(local_device=main_device):
     result[:, 1 * L] = input[:, 1 * L]
     result[:, 2 * L] = input[:, 2 * L]
     result[:, 3 * L] = input[:, 3 * L]
+    correct = (result == targets).min(dim=1).values.long()
+    predicted_parts = input.new(input.size(0), 4)
+
+    nb = 0
+
+    # We consider all the configurations that we train for
+    for struct, quad_generate, _, _ in quiz_machine.test_structures:
+        i = quiz_machine.problem.indices_select(quizzes=input, struct=struct)
+        nb += i.long().sum()
+
+        predicted_parts[i] = torch.tensor(quad_generate, device=result.device)[
+            None, :
+        ]
+        solution_is_deterministic = predicted_parts[i].sum(dim=-1) == 1
+        correct[i] = (2 * correct[i] - 1) * (solution_is_deterministic).long()
+
+    assert nb == input.size(0)
+
+    nb_correct = (correct == 1).long().sum()
+    nb_total = (correct != 0).long().sum()
+
+    self.logger(
+        f"test_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)"
+    )
+
+    correct_parts = predicted_parts * correct[:, None]
+
     filename = f"prediction_ae_{n_epoch:04d}.png"
 
     quiz_machine.problem.save_quizzes_as_image(
         args.result_dir,
         filename,
         quizzes=result,
+        predicted_parts=predicted_parts,
+        correct_parts=correct_parts,
     )
 
     log_string(f"wrote {filename}")
diff --git a/mygpt.py b/mygpt.py
index 8379a57..a744224 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -164,7 +164,7 @@ class TrainablePositionalEncoding(nn.Module):
         self.cache_y = bs.x.new(bs.x.size())
 
         self.cache_y[:, bs.first : bs.first + bs.nb] = (
-            bs.slice() + self.pe[bs.first : bs.first + bs.nb]
+            bs.slice() + self.pe[:, bs.first : bs.first + bs.nb, :]
        )
 
         return BracketedSequence(self.cache_y, bs.first, bs.nb)
-- 
2.39.5
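
In the main.py hunks, test_ae now blanks out the positions flagged by mask_generate before the forward pass (input = (mask_generate == 0).long() * input) and then scores each quiz against the untouched targets, counting only quizzes whose solution is deterministic. The snippet below is a minimal standalone sketch of that scoring logic, not code from the repository; the shapes and the name predicted_quads are assumptions (result and targets as (N, T) token grids, predicted_quads as the per-quiz 0/1 quadruple saying which of the four parts was generated).

import torch


def score_quizzes(result, targets, predicted_quads):
    # A quiz counts as solved only if every token matches the target.
    correct = (result == targets).min(dim=1).values.long()

    # Keep only quizzes whose solution is deterministic, i.e. exactly one of
    # the four parts had to be generated; map to +1 correct / -1 wrong / 0 ignored.
    deterministic = predicted_quads.sum(dim=-1) == 1
    correct = (2 * correct - 1) * deterministic.long()

    nb_correct = (correct == 1).long().sum()
    nb_total = (correct != 0).long().sum()
    return nb_correct, nb_total

The +1 / -1 / 0 encoding mirrors the patch: it lets the same correct tensor drive both the accuracy ratio and the per-part coloring passed to save_quizzes_as_image via correct_parts.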
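The mygpt.py hunk is an indexing fix: assuming the trainable positional-encoding table is stored with a leading broadcast dimension, shape (1, len_max, dim), the old expression self.pe[bs.first : bs.first + bs.nb] sliced the batch dimension instead of the time dimension. A rough illustration of the difference, under that shape assumption:

import torch

len_max, dim, first, nb = 16, 8, 3, 2
pe = torch.randn(1, len_max, dim)   # positional-encoding table with a broadcast batch dim
x = torch.randn(5, nb, dim)         # slice of the batch currently being processed

old = pe[first : first + nb]        # slices dim 0 (size 1) -> shape (0, len_max, dim), empty
new = pe[:, first : first + nb, :]  # slices the time dim  -> shape (1, nb, dim)

y = x + new                         # broadcasts over the batch: (5, nb, dim)
print(old.shape, new.shape, y.shape)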