From: François Fleuret Date: Thu, 5 Sep 2024 15:28:25 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=f1fd44ffc08596aa51a323c64a54d0bb79259c0c;p=culture.git Update. --- diff --git a/grids.py b/grids.py index 2717b22..9e80f62 100755 --- a/grids.py +++ b/grids.py @@ -284,11 +284,14 @@ class Grids(problem.Problem): self.cache_rec_coo = {} all_tasks = [ + ############################################ fundamental ones self.task_replace_color, self.task_translate, self.task_grow, - self.task_half_fill, self.task_frame, + ############################################ + ############################################ + self.task_half_fill, self.task_detect, self.task_scale, self.task_symbols, @@ -700,6 +703,27 @@ class Grids(problem.Problem): X[i1:i2, j1:j2] = c[n] f_X[i1:i2, j1:j2] = c[n if n > 0 else -1] + # @torch.compile + def task_symmetry(self, A, f_A, B, f_B): + a, b = torch.randint(2, (2,)) + nb_rec = 3 + c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + r = self.rec_coo(nb_rec, prevent_overlap=True) + for n in range(nb_rec): + i1, j1, i2, j2 = r[n] + X[i1:i2, j1:j2] = c[n] + f_X[i1:i2, j1:j2] = c[n] + X[: self.height // 2] = c[-1] + f_X[: self.height // 2] = f_X.flip([0])[: self.height // 2] + if a == 1: + X[...] = X.clone().t() + f_X[...] = f_X.clone().t() + if b == 1: + Z = X.clone() + X[...] = f_X + f_X[...] = Z + # @torch.compile def task_translate(self, A, f_A, B, f_B): while True: @@ -1812,7 +1836,7 @@ if __name__ == "__main__": # for t in grids.all_tasks: - for t in [grids.task_recworld_immobile]: + for t in [grids.task_symmetry]: print(t.__name__) w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t]) grids.save_quizzes_as_image( diff --git a/main.py b/main.py index 174b9b8..8e938db 100755 --- a/main.py +++ b/main.py @@ -284,6 +284,24 @@ else: assert args.nb_train_samples % args.batch_size == 0 assert args.nb_test_samples % args.batch_size == 0 +# ------------------------------------------------------ +alien_problem = grids.Grids( + max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100, + chunk_size=100, + nb_threads=args.nb_threads, + tasks="symmetry", +) + +alien_quiz_machine = quiz_machine.QuizMachine( + problem=alien_problem, + batch_size=args.inference_batch_size, + result_dir=args.result_dir, + logger=log_string, + device=main_device, +) + +# ------------------------------------------------------ + ###################################################################### problem = grids.Grids( @@ -918,7 +936,15 @@ def targets_and_prediction(model, input, mask_generate, prompt_noise=0.0): return targets, logits -def run_ae_test(model, quiz_machine, n_epoch, c_quizzes=None, local_device=main_device): +###################################################################### + + +def run_ae_test( + model, quiz_machine, n_epoch, c_quizzes=None, local_device=main_device, prefix=None +): + if prefix is not None: + prefix = prefix + "_" + with torch.autograd.no_grad(): model.eval().to(local_device) @@ -940,7 +966,7 @@ def run_ae_test(model, quiz_machine, n_epoch, c_quizzes=None, local_device=main_ nb_test_samples += input.size(0) log_string( - f"test_loss {n_epoch} model {model.id} {acc_test_loss/nb_test_samples}" + f"{prefix}test_loss {n_epoch} model {model.id} {acc_test_loss/nb_test_samples}" ) # Compute the accuracy and save some images @@ -975,15 +1001,16 @@ def run_ae_test(model, quiz_machine, n_epoch, c_quizzes=None, local_device=main_ record_nd.append((result[nd], predicted_parts[nd], correct_parts[nd])) log_string( - f"test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)" + f"{prefix}test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)" ) - model.test_accuracy = nb_correct / nb_total + if prefix is None: + model.test_accuracy = nb_correct / nb_total # Save some images for f, record in [("prediction", record_d), ("generation", record_nd)]: - filename = f"culture_{f}_{n_epoch:04d}_{model.id:02d}.png" + filename = f"{prefix}culture_{f}_{n_epoch:04d}_{model.id:02d}.png" result, predicted_parts, correct_parts = bag_to_tensors(record) @@ -1366,6 +1393,17 @@ for n_epoch in range(current_epoch, args.nb_epochs): # -------------------------------------------------------------------- + # run_ae_test( + # model, + # alien_quiz_machine, + # n_epoch, + # c_quizzes=None, + # local_device=main_device, + # prefix="alien", + # ) + + # exit(0) + # one_ae_epoch(models[0], quiz_machine, n_epoch, main_device) # exit(0)