From: François Fleuret Date: Fri, 26 Jul 2024 07:40:04 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=1a90398bab34cc30685394608a2429383295d71b;p=culture.git Update. --- diff --git a/grids.py b/grids.py index 67a5c97..3453d4a 100755 --- a/grids.py +++ b/grids.py @@ -336,12 +336,14 @@ class Grids(problem.Problem): predicted_parts=None, correct_parts=None, comments=None, - comment_height=64, + comment_height=48, nrow=4, - margin=8, + margin=8,ff ): quizzes = quizzes.to("cpu") - self.check_structure(quizzes, ("A", "f_A", "B", "f_B")) + + if not self.check_structure(quizzes, ("A", "f_A", "B", "f_B")): + print(f"**WARNING** {filename} is not in A/f_A/B/f_B order") S = self.height * self.width @@ -1490,6 +1492,34 @@ class Grids(problem.Problem): X[i2:ii, jj1:jj2] = c[4] f_X[i2:ii, jj1:jj2] = c[4] + def task_science_dot(self, A, f_A, B, f_B): + nb_rec = 3 + c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + while True: + X[...] = 0 + f_X[...] = 0 + r = self.rec_coo(nb_rec, prevent_overlap=True) + i, j = ( + torch.randint(self.height, (1,)).item(), + torch.randint(self.width, (1,)).item(), + ) + q = 0 + for n in range(nb_rec): + i1, j1, i2, j2 = r[n] + X[i1:i2, j1:j2] = c[n] + f_X[i1:i2, j1:j2] = c[n] + if i >= i1 and i < i2: + q += 1 + f_X[i, j1:j2] = c[-1] + if j >= j1 and j < j2: + q += 1 + f_X[i1:i2, j] = c[-1] + X[i, j] = c[-1] + f_X[i, j] = c[-1] + if q >= 2: + break + # end_tasks ###################################################################### @@ -1529,13 +1559,13 @@ class Grids(problem.Problem): return quizzes - def save_some_examples(self, result_dir): + def save_some_examples(self, result_dir, prefix=""): nb, nrow = 128, 4 for t in self.all_tasks: print(t.__name__) quizzes = self.generate_w_quizzes_(nb, tasks=[t]) self.save_quizzes_as_image( - result_dir, t.__name__ + ".png", quizzes, nrow=nrow + result_dir, prefix + t.__name__ + ".png", quizzes, nrow=nrow ) @@ -1588,7 +1618,7 @@ if __name__ == "__main__": # for t in grids.all_tasks: - for t in [grids.task_science_implicit]: + for t in [grids.task_science_dot]: print(t.__name__) quizzes = grids.generate_w_quizzes_(nb, tasks=[t]) grids.save_quizzes_as_image( diff --git a/main.py b/main.py index 4d618cc..4310e6b 100755 --- a/main.py +++ b/main.py @@ -95,7 +95,7 @@ parser.add_argument("--proba_understands", type=float, default=0.9) parser.add_argument("--proba_not_understands", type=float, default=0.5) -parser.add_argument("--temperature_hot", type=float, default=1.5) +parser.add_argument("--temperature_hot", type=float, default=2) parser.add_argument("--temperature_cold", type=float, default=0.75) @@ -323,6 +323,9 @@ elif args.problem == "grids": tasks=args.grids_science_tasks, ) science_w_quizzes = science_problem.generate_w_quizzes(args.nb_test_samples) + if not args.resume: + problem.save_some_examples(args.result_dir, "science_") + else: raise ValueError @@ -447,6 +450,71 @@ def one_epoch(model, quiz_machine, local_device=main_device): ###################################################################### +def save_additional_results(models, science_w_quizzes): + for model in models: + c_quizzes = quiz_machine.generate_c_quizzes( + 128, + model_for_generation=model, + temperature_hot=args.temperature_hot, + temperature_cold=args.temperature_cold, + ) + + c_quizzes = quiz_machine.problem.reconfigure( + c_quizzes, ("A", "f_A", "B", "f_B") + ) + + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, + f"non_validated_{n_epoch:04d}_{model.id:02d}.png", + c_quizzes, + ) + + ###################################################################### + + if science_w_quizzes is not None: + for model in models: + struct = ("A", "f_A", "B", "f_B") + mask = (0, 0, 0, 1) + result, correct = quiz_machine.predict( + model=model, + quizzes=science_w_quizzes.to(main_device), + struct=struct, + mask=mask, + ) + + predicted_parts = torch.tensor(mask, device=correct.device)[None, :] + correct = (2 * correct - 1) * (predicted_parts.sum(dim=-1) == 1).long() + + nb_correct = (correct == 1).long().sum() + nb_total = (correct != 0).long().sum() + + log_string( + f"science_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total}" + ) + + i = correct == 1 + j = correct != 1 + + result = torch.cat([result[i], result[j]], dim=0) + correct = torch.cat([correct[i], correct[j]], dim=0) + correct_parts = predicted_parts * correct[:, None] + + result = result[:128] + predicted_parts = predicted_parts[:128] + correct_parts = correct_parts[:128] + + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, + f"culture_science_{n_epoch:04d}_{model.id:02d}.png", + quizzes=result, + predicted_parts=predicted_parts, + correct_parts=correct_parts, + ) + + +###################################################################### + + def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100): nb_to_validate = nb_for_train + nb_for_test nb_to_generate_per_iteration = max(args.physical_batch_size, nb_to_validate) @@ -562,7 +630,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 if vq.size(0) > 0: number_correct_responses = 0 - for r in range(args.nb_rounds): + for r in range(10): number_correct_responses += quiz_machine.models_successes(models, vq) comments = [] @@ -740,39 +808,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): ) log_string(f"wrote {filename}") - for model in weakest_models: - c_quizzes = quiz_machine.generate_c_quizzes( - 128, - model_for_generation=model, - temperature_hot=args.temperature_hot, - temperature_cold=args.temperature_cold, - ) - - c_quizzes = quiz_machine.problem.reconfigure( - c_quizzes, ("A", "f_A", "B", "f_B") - ) - - quiz_machine.problem.save_quizzes_as_image( - args.result_dir, - f"non_validated_{n_epoch:04d}_{model.id:02d}.png", - c_quizzes, - ) - - ###################################################################### - - if science_w_quizzes is not None: - result, correct = quiz_machine.predict( - model=model, - quizzes=science_w_quizzes.to(main_device), - struct=("A", "f_A", "B", "f_B"), - mask=(0, 0, 0, 1), - ) - - nb_correct = (correct == 1).long().sum() - nb_total = (correct != 0).long().sum() - log_string( - f"science_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total}" - ) + save_additional_results() ###################################################################### diff --git a/quiz_machine.py b/quiz_machine.py index 8e40921..a9319c7 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -132,7 +132,10 @@ class QuizMachine: self.train_struct = [ ("A", "f_A", "B", "f_B"), # The standard order ("f_A", "A", "f_B", "B"), # The reverse order for validation + ("B", "f_B", "A", "f_A"), + ("f_B", "B", "f_A", "A"), ("f_B", "f_A", "A", "B"), # The synthesis order + ("f_B", "f_A", "A", "B"), # twice! ] self.LOCK_C_QUIZZES = threading.Lock() @@ -224,6 +227,8 @@ class QuizMachine: for struct, mask in [ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1)), (("f_A", "A", "f_B", "B"), (0, 0, 0, 1)), + (("B", "f_B", "A", "f_A"), (0, 0, 0, 1)), + (("f_B", "B", "f_A", "A"), (0, 0, 0, 1)), (("f_B", "f_A", "A", "B"), (0, 1, 1, 1)), ]: i = self.problem.indices_select(quizzes=input, struct=struct) @@ -490,3 +495,33 @@ class QuizMachine: return c_quizzes.to("cpu") ###################################################################### + + def generate_c_quizzes_simple( + self, + nb, + model_for_generation, + temperature_hot=1.0, + temperature_cold=1.0, + ): + c_quizzes = self.problem.create_empty_quizzes(nb, ("A", "f_A", "B", "f_B")) + c_quizzes = c_quizzes.to(self.device) + + seq_logproba = torch.zeros(nb, device=self.device) + + lt_noisy = lambda s, logits: logits / temperature_hot + + masked_inplace_autoregression( + model=model_for_generation, + batch_size=self.batch_size, + input=c_quizzes, + ar_mask=self.make_ar_mask( + c_quizzes, ("A", "f_A", "B", "f_B"), (1, 1, 1, 1) + ), + seq_logproba=seq_logproba, + logit_transformer=lt_noisy, + device=self.device, + ) + + return c_quizzes.to("cpu") + + ######################################################################