From: François Fleuret Date: Fri, 23 Aug 2024 05:04:41 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=1207bca5af71088ed11af346cfb98cd3c7ca4489;p=culture.git Update. --- diff --git a/grids.py b/grids.py index 35b3cff..98a0581 100755 --- a/grids.py +++ b/grids.py @@ -148,30 +148,30 @@ class Grids(problem.Problem): ("gray", [128, 128, 128]), ] - def check_structure(self, quizzes, struct): + def check_order(self, quizzes, quad_order): S = self.height * self.width return ( - (quizzes[:, 0 * (S + 1)] == self.l2tok[struct[0]]) - & (quizzes[:, 1 * (S + 1)] == self.l2tok[struct[1]]) - & (quizzes[:, 2 * (S + 1)] == self.l2tok[struct[2]]) - & (quizzes[:, 3 * (S + 1)] == self.l2tok[struct[3]]) + (quizzes[:, 0 * (S + 1)] == self.l2tok[quad_order[0]]) + & (quizzes[:, 1 * (S + 1)] == self.l2tok[quad_order[1]]) + & (quizzes[:, 2 * (S + 1)] == self.l2tok[quad_order[2]]) + & (quizzes[:, 3 * (S + 1)] == self.l2tok[quad_order[3]]) ).all() - def get_structure(self, quizzes): + def get_order(self, quizzes): S = self.height * self.width - struct = tuple( + quad_order = tuple( self.tok2l[n.item()] for n in quizzes.reshape(quizzes.size(0), 4, S + 1)[0, :, 0] ) - self.check_structure(quizzes, struct) - return struct + self.check_order(quizzes, quad_order) + return quad_order - def inject_noise(self, quizzes, noise, struct, quad): - assert self.check_structure(quizzes, struct=struct) + def inject_noise(self, quizzes, noise, quad_order, quad_noise): + assert self.check_order(quizzes, quad_order=quad_order) S = self.height * self.width - mask = torch.tensor(quad, device=quizzes.device) + mask = torch.tensor(quad_noise, device=quizzes.device) mask = mask[None, :, None].expand(1, 4, S + 1).clone() mask[:, :, 0] = 0 mask = mask.reshape(1, -1).expand_as(quizzes) @@ -182,20 +182,20 @@ class Grids(problem.Problem): return quizzes # What a mess - def reconfigure(self, quizzes, struct=("A", "f_A", "B", "f_B")): + def reconfigure(self, 
quizzes, quad_order=("A", "f_A", "B", "f_B")): if torch.is_tensor(quizzes): - return self.reconfigure([quizzes], struct=struct)[0] + return self.reconfigure([quizzes], quad_order=quad_order)[0] S = self.height * self.width result = [x.new(x.size()) for x in quizzes] - struct_from = self.get_structure(quizzes[0][:1]) - i = self.indices_select(quizzes[0], struct_from) + quad_order_from = self.get_order(quizzes[0][:1]) + i = self.indices_select(quizzes[0], quad_order_from) - sf = dict((l, n) for n, l in enumerate(struct_from)) + sf = dict((l, n) for n, l in enumerate(quad_order_from)) for q in range(4): - k = sf[struct[q]] + k = sf[quad_order[q]] for x, y in zip(quizzes, result): l = x.size(1) // 4 y[i, q * l : (q + 1) * l] = x[i, k * l : (k + 1) * l] @@ -204,7 +204,7 @@ class Grids(problem.Problem): if j.any(): for z, y in zip( - self.reconfigure([x[j] for x in quizzes], struct=struct), result + self.reconfigure([x[j] for x in quizzes], quad_order=quad_order), result ): y[j] = z @@ -212,36 +212,36 @@ class Grids(problem.Problem): def trivial(self, quizzes): S = self.height * self.width - assert self.check_structure(quizzes, struct=("A", "f_A", "B", "f_B")) + assert self.check_order(quizzes, quad_order=("A", "f_A", "B", "f_B")) a = quizzes.reshape(quizzes.size(0), 4, S + 1)[:, :, 1:] return (a[:, 0] == a[:, 1]).min(dim=1).values | (a[:, 2] == a[:, 3]).min( dim=1 ).values def make_quiz_mask( - self, quizzes, struct=("A", "f_A", "B", "f_B"), quad=(0, 0, 0, 1) + self, quizzes, quad_order=("A", "f_A", "B", "f_B"), quad_mask=(0, 0, 0, 1) ): - assert self.check_structure(quizzes, struct) + assert self.check_order(quizzes, quad_order) ar_mask = quizzes.new_zeros(quizzes.size()) S = self.height * self.width a = ar_mask.reshape(ar_mask.size(0), 4, S + 1)[:, :, 1:] - a[:, 0, :] = quad[0] - a[:, 1, :] = quad[1] - a[:, 2, :] = quad[2] - a[:, 3, :] = quad[3] + a[:, 0, :] = quad_mask[0] + a[:, 1, :] = quad_mask[1] + a[:, 2, :] = quad_mask[2] + a[:, 3, :] = quad_mask[3] return 
ar_mask - def indices_select(self, quizzes, struct=("A", "f_A", "B", "f_B")): + def indices_select(self, quizzes, quad_order=("A", "f_A", "B", "f_B")): S = self.height * self.width q = quizzes.reshape(quizzes.size(0), 4, S + 1) return ( - (q[:, 0, 0] == self.l2tok[struct[0]]) - & (q[:, 1, 0] == self.l2tok[struct[1]]) - & (q[:, 2, 0] == self.l2tok[struct[2]]) - & (q[:, 3, 0] == self.l2tok[struct[3]]) + (q[:, 0, 0] == self.l2tok[quad_order[0]]) + & (q[:, 1, 0] == self.l2tok[quad_order[1]]) + & (q[:, 2, 0] == self.l2tok[quad_order[2]]) + & (q[:, 3, 0] == self.l2tok[quad_order[3]]) ) def __init__( @@ -1707,13 +1707,13 @@ class Grids(problem.Problem): ###################################################################### - def create_empty_quizzes(self, nb, struct=("A", "f_A", "B", "f_B")): + def create_empty_quizzes(self, nb, quad_order=("A", "f_A", "B", "f_B")): S = self.height * self.width quizzes = torch.zeros(nb, 4 * (S + 1), dtype=torch.int64) - quizzes[:, 0 * (S + 1)] = self.l2tok[struct[0]] - quizzes[:, 1 * (S + 1)] = self.l2tok[struct[1]] - quizzes[:, 2 * (S + 1)] = self.l2tok[struct[2]] - quizzes[:, 3 * (S + 1)] = self.l2tok[struct[3]] + quizzes[:, 0 * (S + 1)] = self.l2tok[quad_order[0]] + quizzes[:, 1 * (S + 1)] = self.l2tok[quad_order[1]] + quizzes[:, 2 * (S + 1)] = self.l2tok[quad_order[2]] + quizzes[:, 3 * (S + 1)] = self.l2tok[quad_order[3]] return quizzes @@ -1764,10 +1764,10 @@ if __name__ == "__main__": # nb = 5 # quizzes = grids.generate_w_quizzes_(nb, tasks=[grids.task_fill]) # print(quizzes) - # print(grids.get_structure(quizzes)) + # print(grids.get_order(quizzes)) # quizzes = grids.reconfigure(quizzes, struct=("A", "B", "f_A", "f_B")) # print("DEBUG2", quizzes) - # print(grids.get_structure(quizzes)) + # print(grids.get_order(quizzes)) # print(quizzes) # i = torch.rand(quizzes.size(0)) < 0.5 @@ -1778,8 +1778,8 @@ if __name__ == "__main__": # print( # i.equal(j), - # grids.get_structure(quizzes[j]), - # grids.get_structure(quizzes[j == False]), + 
# grids.get_order(quizzes[j]),
+    #     grids.get_order(quizzes[j == False]),
     #     )

     # exit(0)

diff --git a/main.py b/main.py
index 2a35209..a65d893 100755
--- a/main.py
+++ b/main.py
@@ -541,7 +541,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
             (solved_c_quizzes[:, model.id], _, _) = quiz_machine.predict(
                 model,
                 solved_c_quizzes[:, model.id],
-                struct=("A", "f_A", "B", "f_B"),
-                quad=(0, 0, 0, 1),
+                quad_order=("A", "f_A", "B", "f_B"),
+                quad_mask=(0, 0, 0, 1),
             )
 
@@ -821,6 +821,33 @@ class MyAttentionVAE(nn.Module):
         return bs
 
 
+def ae_batches(quiz_machine, nb, data_structures, local_device, desc=None):
+    full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input(
+        args.nb_train_samples, data_structures=data_structures
+    )
+
+    src = zip(
+        full_input.split(args.batch_size),
+        full_mask_generate.split(args.batch_size),
+        full_mask_loss.split(args.batch_size),
+    )
+
+    if desc is not None:
+        src = tqdm.tqdm(
+            src,
+            dynamic_ncols=True,
+            desc=desc,
+            total=full_input.size(0) // args.batch_size,
+        )
+
+    for input, mask_generate, mask_loss in src:
+        yield (
+            input.to(local_device),
+            mask_generate.to(local_device),
+            mask_loss.to(local_device),
+        )
+
+
 def test_ae(local_device=main_device):
     model = MyAttentionVAE(
         vocabulary_size=vocabulary_size,
@@ -832,6 +859,14 @@ def test_ae(local_device=main_device):
         dropout=args.dropout,
     ).to(main_device)
 
+    data_structures = [
+        (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
+        (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (1, 1, 1, 1)),
+        (("A", "f_A", "B", "f_B"), (0, 1, 0, 0), (1, 0, 0, 0), (1, 1, 1, 1)),
+        (("A", "f_A", "B", "f_B"), (1, 0, 0, 0), (0, 1, 0, 0), (1, 1, 1, 1)),
+        (("A", "f_A", "B", "f_B"), (1, 1, 1, 0), (0, 0, 0, 0), (1, 1, 1, 1)),
+    ]
+
     model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
 
     model.to(local_device).train()
@@ -847,26 +882,13 @@ def test_ae(local_device=main_device):
         model.train()
         nb_train_samples, acc_train_loss = 0, 0.0
 
-        
full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input( - args.nb_train_samples - ) - - src = zip( - full_input.split(args.batch_size), - full_mask_generate.split(args.batch_size), - full_mask_loss.split(args.batch_size), - ) - - for input, mask_generate, mask_loss in tqdm.tqdm( - src, - dynamic_ncols=True, - desc="training", - total=full_input.size(0) // args.batch_size, + for input, mask_generate, mask_loss in ae_batches( + quiz_machine, + args.nb_train_samples, + data_structures, + local_device, + "training", ): - input = input.to(local_device) - mask_generate = mask_generate.to(local_device) - mask_loss = mask_loss.to(local_device) - if nb_train_samples % args.batch_size == 0: model.optimizer.zero_grad() @@ -911,26 +933,13 @@ def test_ae(local_device=main_device): nb_test_samples, acc_test_loss = 0, 0.0 - full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input( - args.nb_test_samples - ) - - src = zip( - full_input.split(args.batch_size), - full_mask_generate.split(args.batch_size), - full_mask_loss.split(args.batch_size), - ) - - for input, mask_generate, mask_loss in tqdm.tqdm( - src, - dynamic_ncols=True, - desc="testing", - total=full_input.size(0) // args.batch_size, + for input, mask_generate, mask_loss in ae_batches( + quiz_machine, + args.nb_test_samples, + data_structures, + local_device, + "test", ): - input = input.to(local_device) - mask_generate = mask_generate.to(local_device) - mask_loss = mask_loss.to(local_device) - targets = input mask_noise = (mask_generate != 0) & ( @@ -960,10 +969,10 @@ def test_ae(local_device=main_device): log_string(f"test_loss {n_epoch} model AE {acc_test_loss/nb_test_samples}") - input, mask_generate, mask_loss = quiz_machine.data_input(128) - input = input.to(local_device) - mask_generate = mask_generate.to(local_device) - mask_loss = mask_loss.to(local_device) + input, mask_generate, mask_loss = next( + ae_batches(quiz_machine, 128, data_structures, local_device) + ) + targets = 
input pred_result = None @@ -1013,8 +1022,10 @@ def test_ae(local_device=main_device): nb = 0 # We consider all the configurations that we train for - for struct, quad_generate, _, _ in quiz_machine.test_structures: - i = quiz_machine.problem.indices_select(quizzes=input, struct=struct) + for quad_order, quad_generate, _, _ in quiz_machine.test_structures: + i = quiz_machine.problem.indices_select( + quizzes=input, quad_order=quad_order + ) nb += i.long().sum() predicted_parts[i] = torch.tensor(quad_generate, device=result.device)[ diff --git a/quiz_machine.py b/quiz_machine.py index bea0d78..0f13964 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -175,39 +175,46 @@ class QuizMachine: quizzes = quizzes[i] self.randomize_configuations_inplace( - quizzes, structs=[s for s, _, _, _ in data_structures] + quizzes, quad_orders=[s for s, _, _, _ in data_structures] ) quiz_mask_generate = quizzes.new_full(quizzes.size(), 1) quiz_mask_loss = quizzes.new_full(quizzes.size(), 1) - for struct, quad_generate, quad_noise, quad_loss in data_structures: - i = self.problem.indices_select(quizzes=quizzes, struct=struct) + for quad_order, quad_generate, quad_noise, quad_loss in data_structures: + i = self.problem.indices_select(quizzes=quizzes, quad_order=quad_order) if i.any(): if self.prompt_noise > 0.0: quizzes[i] = self.problem.inject_noise( - quizzes[i], self.prompt_noise, struct=struct, quad=quad_noise + quizzes[i], + self.prompt_noise, + quad_order=quad_order, + quad_noise=quad_noise, ) quiz_mask_generate[i] = self.make_quiz_mask( - quizzes=quizzes[i], struct=struct, quad=quad_generate + quizzes=quizzes[i], quad_order=quad_order, quad_mask=quad_generate ) quiz_mask_loss[i] = self.make_quiz_mask( - quizzes=quizzes[i], struct=struct, quad=quad_loss + quizzes=quizzes[i], quad_order=quad_order, quad_mask=quad_loss ) return quizzes, quiz_mask_generate, quiz_mask_loss ###################################################################### - def make_quiz_mask(self, quizzes, 
struct, quad):
-        assert struct in [s for s, _, _, _ in self.train_structures]
-        return self.problem.make_quiz_mask(quizzes, struct=struct, quad=quad)
+    def make_quiz_mask(self, quizzes, quad_order, quad_mask):
+        assert quad_order in [s for s, _, _, _ in self.train_structures]
+        return self.problem.make_quiz_mask(
+            quizzes, quad_order=quad_order, quad_mask=quad_mask
+        )
 
     ######################################################################
 
-    def predict(self, model, quizzes, struct, quad):
+    def predict(self, model, quizzes, quad_order, quad_mask):
         quizzes = quizzes.to(self.device)
-        ar_mask = self.make_quiz_mask(quizzes=quizzes, struct=struct, quad=quad)
+        ar_mask = self.make_quiz_mask(
+            quizzes=quizzes, quad_order=quad_order, quad_mask=quad_mask
+        )
         result = quizzes * (1 - ar_mask)
 
         seq_logprobas = torch.zeros(quizzes.size(0), device=self.device)
@@ -239,11 +246,11 @@ class QuizMachine:
         nb = 0
 
         # We consider all the configurations that we train for
-        for struct, quad_generate, _, _ in self.test_structures:
-            i = self.problem.indices_select(quizzes=input, struct=struct)
+        for quad_order, quad_generate, _, _ in self.test_structures:
+            i = self.problem.indices_select(quizzes=input, quad_order=quad_order)
             nb += i.long().sum()
             result[i], correct[i], _ = self.predict(
-                model=model, quizzes=input[i], struct=struct, quad=quad_generate
+                model=model, quizzes=input[i], quad_order=quad_order, quad_mask=quad_generate
             )
 
             predicted_parts[i] = torch.tensor(quad_generate, device=self.device)[
@@ -282,11 +289,11 @@ class QuizMachine:
 
     ######################################################################
 
-    def randomize_configuations_inplace(self, quizzes, structs):
-        r = torch.randint(len(structs), (quizzes.size(0),), device=quizzes.device)
-        for c in range(len(structs)):
+    def randomize_configuations_inplace(self, quizzes, quad_orders):
+        r = torch.randint(len(quad_orders), (quizzes.size(0),), device=quizzes.device)
+        for c in range(len(quad_orders)):
             quizzes[r == c] = 
self.problem.reconfigure( - quizzes[r == c], struct=structs[c] + quizzes[r == c], quad_order=quad_orders[c] ) ###################################################################### @@ -310,7 +317,7 @@ class QuizMachine: self, model, c_quizzes, - struct, + quad_order, quad_loss, quad_noise=None, temperature=1.0, @@ -319,7 +326,7 @@ class QuizMachine: if device is None: device = self.device - c_quizzes = self.problem.reconfigure(c_quizzes, struct) + c_quizzes = self.problem.reconfigure(c_quizzes, quad_order) seq_logprobas = torch.zeros( c_quizzes.size(0), @@ -328,7 +335,7 @@ class QuizMachine: # if self.prompt_noise > 0.0 and quad_noise is not None: # c_quizzes = self.problem.inject_noise( - # c_quizzes, self.prompt_noise, struct=struct, quad=quad_noise + # c_quizzes, self.prompt_noise, quad_order=quad_order, quad_noise=quad_noise # ) with torch.autograd.no_grad(): @@ -341,7 +348,7 @@ class QuizMachine: ): input = input.to(device) quiz_mask_loss = self.make_quiz_mask( - input, struct=struct, quad=quad_loss + input, quad_order=quad_order, quad_mask=quad_loss ) output = model(mygpt.BracketedSequence(input)).x / temperature l[...] 
= ( @@ -361,13 +368,13 @@ class QuizMachine: c_quizzes = None for n_step, setup in enumerate(procedure): - struct, quad_generate, model_modifier = setup + quad_order, quad_generate, model_modifier = setup if c_quizzes is None: - c_quizzes = self.problem.create_empty_quizzes(nb, struct) + c_quizzes = self.problem.create_empty_quizzes(nb, quad_order) c_quizzes = c_quizzes.to(self.device) - elif struct != pred_struct: - c_quizzes = self.problem.reconfigure(c_quizzes, struct) - pred_struct = struct + elif quad_order != pred_quad_order: + c_quizzes = self.problem.reconfigure(c_quizzes, quad_order) + pred_quad_order = quad_order if model_modifier is not None: model_modifier(model_for_generation) @@ -375,7 +382,9 @@ class QuizMachine: self.autoregression( model=model_for_generation, input=c_quizzes, - ar_mask=self.make_quiz_mask(c_quizzes, struct, quad_generate), + ar_mask=self.make_quiz_mask( + quizzes=c_quizzes, quad_order=quad_order, quad_mask=quad_generate + ), seq_logprobas=seq_logprobas, progress_bar_desc=f"autoregression {n_step+1}/{len(procedure)}", )