From: François Fleuret Date: Tue, 18 Jul 2023 20:35:59 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=d6f73f1d5093fb098e822e14db382dd3a1c63a2a;p=picoclvr.git Update. --- diff --git a/main.py b/main.py index e3fd9f0..0d4930d 100755 --- a/main.py +++ b/main.py @@ -82,6 +82,17 @@ parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth") ############################## # picoclvr options +parser.add_argument("--sandbox_level", type=int, default=0) + +parser.add_argument("--sandbox_levels_nb_items", type=int, default=25) + +parser.add_argument("--sandbox_levels_len_source", type=int, default=5) + +parser.add_argument("--sandbox_levels_len_result", type=int, default=8) + +############################## +# picoclvr options + parser.add_argument("--picoclvr_nb_colors", type=int, default=5) parser.add_argument("--picoclvr_height", type=int, default=12) @@ -265,8 +276,28 @@ picoclvr_pruner_eval = ( ###################################################################### if args.task == "sandbox": + if args.sandbox_level == 0: + problem = tasks.ProblemLevel0( + nb_sentences=args.sandbox_levels_nb_items, + len_prompt=args.sandbox_levels_len_source, + len_result=args.sandbox_levels_len_result, + ) + elif args.sandbox_level == 1: + problem = tasks.ProblemLevel1( + nb_operators=args.sandbox_levels_nb_items, + len_source=args.sandbox_levels_len_source, + len_result=args.sandbox_levels_len_result, + ) + elif args.sandbox_level == 2: + problem = tasks.ProblemLevel2( + len_source=args.sandbox_levels_len_source, + len_result=args.sandbox_levels_len_result, + ) + else: + raise ValueError(f"Unknown sandbox level {args.sandbox_level}") + task = tasks.SandBox( - tasks.ProblemLevel2(), + problem, # tasks.ProblemAddition(zero_padded=False, inverted_result=False), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, diff --git a/tasks.py b/tasks.py index 73f61bf..e7c2f75 100755 --- a/tasks.py +++ b/tasks.py @@ -76,7 +76,7 @@ class Problem: class ProblemLevel0(Problem): def __init__(self, nb_sentences=100, len_prompt=5, len_result=5): - self.seq = torch.randint(10, (nb_seq, len_prompt + 1 + len_result)) + self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result)) self.seq[:, len_prompt] = 10 def generate_sequences(self, nb):