From: François Fleuret Date: Thu, 1 Aug 2024 09:51:54 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=f848e5847554870b26e6219e33c845669f4663b3;p=culture.git Update. --- diff --git a/quiz_machine.py b/quiz_machine.py index bfa7f97..a042431 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -82,11 +82,11 @@ class QuizMachine: self.prompt_noise = prompt_noise self.understood_structures = [ - (("A", "f_A", "B", "f_B"), (0, 0, 0, 1)), - (("f_A", "A", "f_B", "B"), (0, 0, 0, 1)), - (("B", "f_B", "A", "f_A"), (0, 0, 0, 1)), - (("f_B", "B", "f_A", "A"), (0, 0, 0, 1)), - (("f_B", "f_A", "A", "B"), (0, 1, 1, 1)), + (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)), + (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)), + (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 0, 0)), + (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 0, 0)), + (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0)), ] self.LOCK_C_QUIZZES = threading.Lock() @@ -178,18 +178,15 @@ class QuizMachine: quizzes, from_w = quizzes[i], from_w[i] self.randomize_configuations_inplace( - quizzes, structs=[s for s, m in self.understood_structures] + quizzes, structs=[s for s, m, _ in self.understood_structures] ) if self.prompt_noise > 0.0: - for struct, mask in self.understood_structures: + for struct, mask, noise_mask in self.understood_structures: i = self.problem.indices_select(quizzes=quizzes, struct=struct) if i.any(): quizzes[i] = self.problem.inject_noise( - quizzes[i], - self.prompt_noise, - struct=struct, - mask=tuple(1 - k for k in mask), + quizzes[i], self.prompt_noise, struct=struct, mask=noise_mask ) return quizzes, from_w @@ -197,7 +194,7 @@ class QuizMachine: ###################################################################### def make_ar_mask(self, quizzes, struct, mask): - assert struct in [s for s, m in self.understood_structures] + assert struct in [s for s, _, _ in self.understood_structures] return self.problem.make_ar_mask(quizzes, struct=struct, mask=mask) ###################################################################### @@ -231,7 +228,7 @@ class QuizMachine: nb = 0 # We consider all the configurations that we train for - for struct, mask in self.understood_structures: + for struct, mask, noise_mask in self.understood_structures: i = self.problem.indices_select(quizzes=input, struct=struct) nb += i.long().sum() result[i], correct[i] = self.predict(