From: François Fleuret Date: Thu, 25 Jul 2024 04:10:02 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=015babd7d91e3d514a24c980a74a171ef6a1d185;p=culture.git Update. --- diff --git a/grids.py b/grids.py index f6129e9..93b027a 100755 --- a/grids.py +++ b/grids.py @@ -137,6 +137,7 @@ class Grids(problem.Problem): self.check_structure(quizzes, struct) return struct + # What a mess def reconfigure(self, quizzes, struct=("A", "f_A", "B", "f_B")): if torch.is_tensor(quizzes): return self.reconfigure([quizzes], struct=struct)[0] @@ -165,11 +166,11 @@ class Grids(problem.Problem): return result - def non_trivial(self, quizzes): + def trivial(self, quizzes): S = self.height * self.width assert self.check_structure(quizzes, struct=("A", "f_A", "B", "f_B")) a = quizzes.reshape(quizzes.size(0), 4, S + 1)[:, :, 1:] - return (a[:, 0] == a[:, 1]).min(dim=1).values & (a[:, 2] == a[:, 3]).min( + return (a[:, 0] == a[:, 1]).min(dim=1).values | (a[:, 2] == a[:, 3]).min( dim=1 ).values diff --git a/main.py b/main.py index fa33b4e..257f40f 100755 --- a/main.py +++ b/main.py @@ -454,7 +454,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 # We discard the trivial ones, according to a criterion # specific to the world quizzes (e.g. B=f(B)) - c_quizzes = c_quizzes[quiz_machine.problem.non_trivial(c_quizzes)] + c_quizzes = c_quizzes[quiz_machine.problem.trivial(c_quizzes) == False] # We go through nb_rounds rounds and keep only quizzes on # which diff --git a/quiz_machine.py b/quiz_machine.py index 4615e3a..2ca584e 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -270,14 +270,12 @@ class QuizMachine: ###################################################################### - def randomize_configuations_inplace(self, quizzes, configurations): - r = torch.randint( - len(configurations), (quizzes.size(0),), device=quizzes.device - ) + def randomize_configuations_inplace(self, quizzes, structs): + r = torch.randint(len(structs), (quizzes.size(0),), device=quizzes.device) - for c in range(len(configurations)): + for c in range(len(structs)): quizzes[r == c] = self.problem.reconfigure( - quizzes[r == c], struct=configurations[c] + quizzes[r == c], struct=structs[c] ) def create_w_quizzes(self, model, nb_train_samples, nb_test_samples): @@ -285,11 +283,11 @@ class QuizMachine: model.test_w_quizzes = self.problem.generate_w_quizzes(nb_test_samples) self.randomize_configuations_inplace( - model.train_w_quizzes, configurations=self.train_struct + model.train_w_quizzes, structs=self.train_struct ) self.randomize_configuations_inplace( - model.test_w_quizzes, configurations=self.train_struct + model.test_w_quizzes, structs=self.train_struct ) ###################################################################### @@ -322,7 +320,7 @@ class QuizMachine: ) self.randomize_configuations_inplace( - model.train_w_quizzes, configurations=self.train_struct + model.train_w_quizzes, structs=self.train_struct ) ######################################################################