From ff912e032df5168acbe6fc3b5879136db6e515cf Mon Sep 17 00:00:00 2001
From: François Fleuret
Date: Thu, 22 Aug 2024 20:44:02 +0200
Subject: [PATCH] Update.

---
 main.py         | 33 +++++++++++++++++++--------------
 mygpt.py        | 24 ++++++++++++++++++++++--
 quiz_machine.py | 11 ++++++-----
 3 files changed, 47 insertions(+), 21 deletions(-)

diff --git a/main.py b/main.py
index fc480b7..36b369e 100755
--- a/main.py
+++ b/main.py
@@ -341,7 +341,7 @@ def run_tests(model, quiz_machine, local_device=main_device):
     nb_test_samples, acc_test_loss = 0, 0.0
     nb_samples_accumulated = 0
 
-    full_input, full_mask_loss = quiz_machine.data_input(
+    full_input, _, full_mask_loss = quiz_machine.data_input(
         args.nb_test_samples, model.test_c_quiz_bags, args.c_quiz_multiplier
     )
     src = zip(
@@ -370,7 +370,7 @@ def run_tests(model, quiz_machine, local_device=main_device):
 
     log_string(f"test_perplexity {n_epoch} model {model.id} {test_perplexity}")
 
-    input, _ = quiz_machine.data_input(
+    input, _, _ = quiz_machine.data_input(
         2000, model.test_c_quiz_bags, args.c_quiz_multiplier
     )
 
@@ -394,7 +394,7 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 
     nb_train_samples, acc_train_loss = 0, 0.0
 
-    full_input, full_mask_loss = quiz_machine.data_input(
+    full_input, _, full_mask_loss = quiz_machine.data_input(
         args.nb_train_samples,
         model.train_c_quiz_bags + common_c_quiz_bags,
         args.c_quiz_multiplier,
@@ -635,7 +635,8 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
 from mygpt import (
     WithResidual,
     CacheWrapper,
-    AddPositionalEncoding,
+    VaswaniPositionalEncoding,
+    TrainablePositionalEncoding,
     QKVAttention,
     BracketedSequence,
 )
@@ -660,7 +661,7 @@ class Thinker(nn.Module):
 
         self.embedding = nn.Sequential(
             CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
-            AddPositionalEncoding(len_max),
+            VaswaniPositionalEncoding(len_max),
         )
 
         def trunk(depth):
@@ -743,7 +744,7 @@ class Thinker(nn.Module):
 from mygpt import (
     WithResidual,
     CacheWrapper,
-    AddPositionalEncoding,
+    VaswaniPositionalEncoding,
     QKVAttention,
     BracketedSequence,
 )
@@ -759,7 +760,7 @@ class MyAttentionVAE(nn.Module):
         nb_heads,
         nb_blocks,
         dropout=0.0,
-        len_max=1e5,
+        len_max=1024,
     ):
         super().__init__()
 
@@ -769,7 +770,7 @@ class MyAttentionVAE(nn.Module):
             CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
         )
 
-        self.positional_encoding = AddPositionalEncoding(len_max)
+        self.positional_encoding = TrainablePositionalEncoding(dim_model, len_max)
 
         trunk_blocks = []
 
@@ -850,7 +851,7 @@ def test_ae(local_device=main_device):
         (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)),
     ]
 
-    full_input, full_mask_loss = quiz_machine.data_input(
+    full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input(
         args.nb_train_samples, data_structures=data_structures
     )
 
@@ -871,7 +872,7 @@ def test_ae(local_device=main_device):
             model.optimizer.zero_grad()
 
             targets = input
-            input = (mask_loss == 0).long() * input
+            input = (mask_generate == 0).long() * input
 
             output = model(mygpt.BracketedSequence(input)).x
             loss = F.cross_entropy(output.transpose(1, 2), targets)
@@ -894,7 +895,9 @@ def test_ae(local_device=main_device):
 
     nb_test_samples, acc_test_loss = 0, 0.0
 
-    full_input, full_mask_loss = quiz_machine.data_input(args.nb_test_samples)
+    full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input(
+        args.nb_test_samples, data_structures=data_structures
+    )
     src = zip(
         full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
     )
@@ -909,7 +912,7 @@ def test_ae(local_device=main_device):
             input = input.to(local_device)
             mask_loss = mask_loss.to(local_device)
             targets = input
-            input = (mask_loss == 0).long() * input
+            input = (mask_generate == 0).long() * input
             output = model(mygpt.BracketedSequence(input)).x
             loss = F.cross_entropy(output.transpose(1, 2), targets)
             acc_test_loss += loss.item() * input.size(0)
@@ -917,11 +920,13 @@ def test_ae(local_device=main_device):
 
     log_string(f"test_loss {n_epoch} model AE {acc_test_loss/nb_test_samples}")
 
-    input, mask_loss = quiz_machine.data_input(128)
+    input, mask_generate, mask_loss = quiz_machine.data_input(
+        128, data_structures=data_structures
+    )
     input = input.to(local_device)
     mask_loss = mask_loss.to(local_device)
     targets = input
-    input = (mask_loss == 0).long() * input
+    input = (mask_generate == 0).long() * input
     logits = model(mygpt.BracketedSequence(input)).x
     dist = torch.distributions.categorical.Categorical(logits=logits)
     result = dist.sample()
diff --git a/mygpt.py b/mygpt.py
index f716fe5..8379a57 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -122,7 +122,7 @@ class WithResidual(nn.Module):
 ##############################
 
 
-class AddPositionalEncoding(nn.Module):
+class VaswaniPositionalEncoding(nn.Module):
     def __init__(self, len_max):
         super().__init__()
         self.len_max = len_max
@@ -153,6 +153,26 @@ class AddPositionalEncoding(nn.Module):
 ##############################
 
 
+class TrainablePositionalEncoding(nn.Module):
+    def __init__(self, dim, len_max):
+        super().__init__()
+        self.len_max = len_max
+        self.pe = nn.Parameter(torch.randn(1, len_max, dim) / math.sqrt(dim))
+
+    def forward(self, bs):
+        if bs.first == 0:
+            self.cache_y = bs.x.new(bs.x.size())
+
+        self.cache_y[:, bs.first : bs.first + bs.nb] = (
+            bs.slice() + self.pe[:, bs.first : bs.first + bs.nb]
+        )
+
+        return BracketedSequence(self.cache_y, bs.first, bs.nb)
+
+
+##############################
+
+
 class EncoderHead(nn.Module):
     def __init__(self, dim_in, dim_out):
         super().__init__()
@@ -338,7 +358,7 @@ class MyGPT(nn.Module):
             CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
         )
 
-        self.positional_encoding = AddPositionalEncoding(len_max)
+        self.positional_encoding = VaswaniPositionalEncoding(len_max)
 
         trunk_blocks = []
 
diff --git a/quiz_machine.py b/quiz_machine.py
index 08f121a..bea0d78 100755
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -178,23 +178,24 @@ class QuizMachine:
             quizzes, structs=[s for s, _, _, _ in data_structures]
         )
 
+        quiz_mask_generate = quizzes.new_full(quizzes.size(), 1)
         quiz_mask_loss = quizzes.new_full(quizzes.size(), 1)
 
-        for struct, _, quad_noise, quad_loss in data_structures:
+        for struct, quad_generate, quad_noise, quad_loss in data_structures:
             i = self.problem.indices_select(quizzes=quizzes, struct=struct)
             if i.any():
                 if self.prompt_noise > 0.0:
                     quizzes[i] = self.problem.inject_noise(
                         quizzes[i], self.prompt_noise, struct=struct, quad=quad_noise
                     )
+                quiz_mask_generate[i] = self.make_quiz_mask(
+                    quizzes=quizzes[i], struct=struct, quad=quad_generate
+                )
                 quiz_mask_loss[i] = self.make_quiz_mask(
                     quizzes=quizzes[i], struct=struct, quad=quad_loss
                 )
 
-        print("quad_loss", quad_loss)
-        print("quiz_mask_loss", quiz_mask_loss)
-
-        return quizzes, quiz_mask_loss
+        return quizzes, quiz_mask_generate, quiz_mask_loss
 
     ######################################################################
 
-- 
2.39.5
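
For readers skimming the diff, the sketch below (not part of the patch) shows the
new TrainablePositionalEncoding in isolation and checks that its cache behaves
consistently. The BracketedSequence stub is a simplified assumption reconstructed
from the calls visible in mygpt.py (x holds the full tensor, [first, first + nb)
is the range of positions being computed, slice() returns it); the real class may
differ.

    # Standalone sketch, not part of the patch. The BracketedSequence stub
    # below is an assumption reconstructed from mygpt.py, not the real class.

    import math
    import torch
    from torch import nn


    class BracketedSequence:
        def __init__(self, x, first=None, nb=None):
            self.x = x
            self.first = 0 if first is None else first
            self.nb = x.size(1) if nb is None else nb

        def slice(self):
            return self.x[:, self.first : self.first + self.nb]


    class TrainablePositionalEncoding(nn.Module):
        def __init__(self, dim, len_max):
            super().__init__()
            self.len_max = len_max
            # One learned vector per position; the 1/sqrt(dim) scale keeps the
            # initial positional vectors on the same order of magnitude as the
            # token embeddings they are added to.
            self.pe = nn.Parameter(torch.randn(1, len_max, dim) / math.sqrt(dim))

        def forward(self, bs):
            if bs.first == 0:
                self.cache_y = bs.x.new(bs.x.size())
            # Slice pe over the position dimension (dim 1), hence the leading ':'.
            self.cache_y[:, bs.first : bs.first + bs.nb] = (
                bs.slice() + self.pe[:, bs.first : bs.first + bs.nb]
            )
            return BracketedSequence(self.cache_y, bs.first, bs.nb)


    if __name__ == "__main__":
        torch.manual_seed(0)
        pe = TrainablePositionalEncoding(dim=8, len_max=1024)
        x = torch.randn(2, 5, 8)

        # A single pass over the full sequence...
        y_full = pe(BracketedSequence(x)).x.clone()

        # ...must match token-by-token passes, which refill the cache the way
        # incremental generation does in mygpt.py.
        for t in range(x.size(1)):
            y_step = pe(BracketedSequence(x, t, 1)).x
        assert torch.allclose(y_full, y_step)

The consistency check matters because the models evaluate sequences both in one
full pass (training) and in chunks (generation), and the positional encoding must
write the same values into its cache in both regimes.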