From a3211f96c7426a613b82a2de87d4dd70640e8f46 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 18 Jul 2023 22:23:03 +0200 Subject: [PATCH] Update. --- main.py | 2 +- tasks.py | 42 ++++++++++++++++++++++++++++++++++++------ 2 files changed, 37 insertions(+), 7 deletions(-) diff --git a/main.py b/main.py index 213524e..e3fd9f0 100755 --- a/main.py +++ b/main.py @@ -266,7 +266,7 @@ picoclvr_pruner_eval = ( if args.task == "sandbox": task = tasks.SandBox( - tasks.ProblemLevel1(), + tasks.ProblemLevel2(), # tasks.ProblemAddition(zero_padded=False, inverted_result=False), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, diff --git a/tasks.py b/tasks.py index 706e1d9..73f61bf 100755 --- a/tasks.py +++ b/tasks.py @@ -96,18 +96,19 @@ class ProblemLevel1(Problem): num_classes=len_source, ) - - def generate_sequences(self, nb): nb_operators = torch.randint(self.operators.size(0), (nb,)) operators = self.operators[nb_operators] - nb_operators = (nb_operators[:, None] // 10 ** torch.arange(self.len_nb_operator-1,-1,-1)) % 10 - marker1 = torch.full((nb,1),10) + nb_operators = ( + nb_operators[:, None] + // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1) + ) % 10 + marker1 = torch.full((nb, 1), 10) source = torch.randint(10, (nb, self.len_source)) - marker2 = torch.full((nb,1),11) + marker2 = torch.full((nb, 1), 11) result = operators.bmm(source[:, :, None]).squeeze(-1) print(f"{nb_operators.dtype=} {marker1.dtype=}") - sequences = torch.cat((nb_operators, marker1, source,marker2,result),1) + sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1) print(f"{sequences.size()=}") ar_mask = (sequences == 11).long() ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1) @@ -117,6 +118,35 @@ class ProblemLevel1(Problem): return "".join("0123456789|>"[x.item()] for x in seq) +class ProblemLevel2(Problem): + def __init__(self, len_source=5, len_result=8): + self.len_source = len_source + self.len_result = len_result + + def generate_sequences(self, nb): + operators = F.one_hot( + torch.rand(nb, self.len_result, self.len_source).argmax(-1), + num_classes=self.len_source, + ) + source1 = torch.randint(10, (nb, self.len_source)) + marker1 = torch.full((nb, 1), 10) + result1 = operators.bmm(source1[:, :, None]).squeeze(-1) + marker2 = torch.full((nb, 1), 11) + source2 = torch.randint(10, (nb, self.len_source)) + marker3 = torch.full((nb, 1), 12) + result2 = operators.bmm(source2[:, :, None]).squeeze(-1) + + sequences = torch.cat( + (source1, marker1, result1, marker2, source2, marker3, result2), 1 + ) + ar_mask = (sequences == 12).long() + ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1) + return sequences, ar_mask + + def seq2str(self, seq): + return "".join("0123456789>|~"[x.item()] for x in seq) + + #################### -- 2.39.5