From: François Fleuret Date: Sat, 8 Jul 2023 09:17:25 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=45a3c70758eb867106537ff7c20491bc32ef5f1e;p=picoclvr.git Update. --- diff --git a/tasks.py b/tasks.py index 912b405..8fe89be 100755 --- a/tasks.py +++ b/tasks.py @@ -748,18 +748,21 @@ class Stack(Task): result = input.clone() stack.remove_popped_values(result, self.nb_stacks, self.nb_digits) ar_mask = (result != input).long() - for n in range(result.size(0)): - logger( - f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}" - ) - masked_inplace_autoregression( - model, - self.batch_size, - result, - ar_mask, - deterministic_synthesis, - device=self.device, - ) + + # for n in range(result.size(0)): + # logger( + # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}" + # ) + + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) + for n in range(result.size(0)): logger( f"test_after {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}" @@ -936,16 +939,19 @@ class Expr(Task): result = input.clone() ar_mask = (result == self.space).long().cumsum(dim=1).clamp(max=1) result = (1 - ar_mask) * result + ar_mask * self.filler - for n in range(result.size(0)): - logger(f"test_before {self.seq2str(result[n])}") - masked_inplace_autoregression( - model, - self.batch_size, - result, - ar_mask, - deterministic_synthesis, - device=self.device, - ) + + # for n in range(result.size(0)): + # logger(f"test_before {self.seq2str(result[n])}") + + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) + correct = (1 - ar_mask) * self.space + ar_mask * input for n in range(result.size(0)): comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else ""