From: François Fleuret Date: Tue, 4 Jul 2023 16:08:55 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=49738bb51b386e62f86f861237cbe32b7a2ad479;p=picoclvr.git Update. --- diff --git a/main.py b/main.py index beafc19..b907e60 100755 --- a/main.py +++ b/main.py @@ -1091,7 +1091,7 @@ class TaskExpr(Task): result = input.clone() filler, space = self.char2id["#"], self.char2id[" "] ar_mask = (result == space).long().cumsum(dim=1).clamp(max=1) - result = (1 - ar_mask) * result + filler * ar_mask + result = (1 - ar_mask) * result + ar_mask * filler masked_inplace_autoregression( model, self.batch_size, result, ar_mask, device=self.device ) @@ -1113,16 +1113,19 @@ class TaskExpr(Task): result = input.clone() filler, space = self.char2id["#"], self.char2id[" "] ar_mask = (result == space).long().cumsum(dim=1).clamp(max=1) - result = (1 - ar_mask) * result + filler * ar_mask + result = (1 - ar_mask) * result + ar_mask * filler for n in range(result.size(0)): s = "".join([self.id2char[k.item()] for k in result[n]]) log_string(f"test_before {s}") masked_inplace_autoregression( model, self.batch_size, result, ar_mask, device=self.device ) + correct = (1 - ar_mask) * space + ar_mask * input for n in range(result.size(0)): s = "".join([self.id2char[k.item()] for k in result[n]]) log_string(f"test_after {s}") + s = "".join([self.id2char[k.item()] for k in correct[n]]) + log_string(f"correct {s}") ############################################################## model.train(t)