From a291e213a152364b74e833200191c08a36451a90 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 19 Jul 2023 17:51:03 +0200 Subject: [PATCH] Update. --- tasks.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tasks.py b/tasks.py index 0f44760..0a4dd6f 100755 --- a/tasks.py +++ b/tasks.py @@ -1042,7 +1042,7 @@ class RPL(Task): ) ], 0, - ).to(self.device) + ) def seq2str(self, seq): return " ".join([self.id2token[i] for i in seq]) @@ -1101,7 +1101,7 @@ class RPL(Task): self.test_input = self.tensorize(test_sequences) if logger is not None: - for x in self.train_input[:10]: + for x in self.train_input[:25]: end = (x != self.t_nul).nonzero().max().item() + 1 seq = [self.id2token[i.item()] for i in x[:end]] s = " ".join(seq) @@ -1120,7 +1120,7 @@ class RPL(Task): input.split(self.batch_size), dynamic_ncols=True, desc=desc ): last = (batch != self.t_nul).max(0).values.nonzero().max() + 3 - batch = batch[:, :last] + batch = batch[:, :last].to(self.device) yield batch def vocabulary_size(self): @@ -1129,6 +1129,7 @@ class RPL(Task): def produce_results( self, n_epoch, model, result_dir, logger, deterministic_synthesis ): + # -------------------------------------------------------------------- def compute_nb_errors(input, nb_to_log=0): result = input.clone() s = (result == self.t_prog).long() @@ -1169,8 +1170,10 @@ class RPL(Task): return sum_nb_total, sum_nb_errors + # -------------------------------------------------------------------- + test_nb_total, test_nb_errors = compute_nb_errors( - self.test_input[:1000], nb_to_log=10 + self.test_input[:1000].to(self.device), nb_to_log=10 ) logger( -- 2.39.5