From 6c8bed86221baae24a7c2aaaa41c009444efb5c9 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 23 Jul 2023 09:37:33 -1000 Subject: [PATCH] Update. --- main.py | 3 ++- problems.py | 40 ++++++++++++++++++++++++++++++++++++++++ tasks.py | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 74 insertions(+), 1 deletion(-) diff --git a/main.py b/main.py index 0f1fbb5..b9b52d6 100755 --- a/main.py +++ b/main.py @@ -355,8 +355,9 @@ if args.task == "sandbox": raise ValueError(f"Unknown sandbox level {args.sandbox_level}") task = tasks.SandBox( - problem, + # problem, # problems.ProblemAddition(zero_padded=False, inverted_result=False), + problems.ProblemLenId(len_max=args.sandbox_levels_len_source), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, batch_size=args.batch_size, diff --git a/problems.py b/problems.py index 5161587..1f4098a 100755 --- a/problems.py +++ b/problems.py @@ -21,6 +21,37 @@ class Problem: #################### +class ProblemLenId(Problem): + def __init__(self, nb_sentences=100, len_max=5): + self.len_max = len_max + + def generate_sequences(self, nb): + k = torch.arange(self.len_max * 3 + 3)[None, :] + l = torch.randint(self.len_max, (2, nb))[:, :, None] + 1 + i = torch.randint(10, (2, nb))[:, :, None] + a = l[0] + b = l[0] + 1 + l[1] + c = l[0] + 1 + l[1] + 1 + l[0] + sequences = ( + (k < a) * i[0] + + (k == a) * 10 + + (k > a) * (k < b) * i[1] + + (k == b) * 11 + + (k > b) * (k < c) * i[1] + + (k == c) * 12 + + (k > c) * 13 + ) + ar_mask = (sequences == 11).long() + ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1) + return sequences, ar_mask + + def seq2str(self, seq): + return "".join("0123456789|>.?"[x.item()] for x in seq) + + +#################### + + class ProblemLevel0(Problem): def __init__(self, nb_sentences=100, len_prompt=5, len_result=5): self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result)) @@ -32,6 +63,12 @@ class ProblemLevel0(Problem): ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1) return sequences, ar_mask + def seq2str(self, seq): + return "".join("0123456789|"[x.item()] for x in seq) + + +#################### + class ProblemLevel1(Problem): def __init__(self, nb_operators=100, len_source=5, len_result=8): @@ -64,6 +101,9 @@ class ProblemLevel1(Problem): return "".join("0123456789|>"[x.item()] for x in seq) +#################### + + class ProblemLevel2(Problem): def __init__(self, len_source=5, len_result=8): self.len_source = len_source diff --git a/tasks.py b/tasks.py index b2f7d7d..038a8ac 100755 --- a/tasks.py +++ b/tasks.py @@ -181,6 +181,38 @@ class SandBox(Task): f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" ) + if save_attention_image is not None: + ns = torch.randint(self.test_input.size(0), (1,)).item() + input = self.test_input[ns : ns + 1].clone() + + with torch.autograd.no_grad(): + t = model.training + model.eval() + model.record_attention(True) + model(BracketedSequence(input)) + model.train(t) + ram = model.retrieve_attention() + model.record_attention(False) + + tokens_output = [c for c in self.problem.seq2str(input[0])] + tokens_input = ["n/a"] + tokens_output[:-1] + for n_head in range(ram[0].size(1)): + filename = os.path.join( + result_dir, f"rpl_attention_{n_epoch}_h{n_head}.pdf" + ) + attention_matrices = [m[0, n_head] for m in ram] + save_attention_image( + filename, + tokens_input, + tokens_output, + attention_matrices, + k_top=10, + # min_total_attention=0.9, + token_gap=12, + layer_gap=50, + ) + logger(f"wrote {filename}") + ###################################################################### -- 2.39.5