From: François Fleuret Date: Thu, 28 Mar 2024 13:55:14 +0000 (+0100) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=c3581ba868cd30cb45fbe2f97b80ddbc1fc26bbb;p=picoclvr.git Update. --- diff --git a/main.py b/main.py index 0f2cb61..9437136 100755 --- a/main.py +++ b/main.py @@ -706,7 +706,7 @@ if args.task == "expr" and args.expr_input_file is not None: # Compute the entropy of the training tokens token_count = 0 -for input in task.batches(split="train"): +for input in task.batches(split="train", desc="train-entropy"): token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1)) token_probas = token_count / token_count.sum() entropy = -torch.xlogy(token_probas, token_probas).sum() @@ -728,9 +728,13 @@ if args.max_percents_of_test_in_train >= 0: yield s nb_test, nb_in_train = 0, 0 - for test_subset in subsets_as_tuples(task.batches(split="test"), 25000): + for test_subset in subsets_as_tuples( + task.batches(split="test", desc="test-check"), 25000 + ): in_train = set() - for train_subset in subsets_as_tuples(task.batches(split="train"), 25000): + for train_subset in subsets_as_tuples( + task.batches(split="train", desc="train-check"), 25000 + ): in_train.update(test_subset.intersection(train_subset)) nb_in_train += len(in_train) nb_test += len(test_subset) diff --git a/tasks.py b/tasks.py index 324376d..3ef64d7 100755 --- a/tasks.py +++ b/tasks.py @@ -1944,7 +1944,7 @@ class Greed(Task): progress_bar_desc=None, ) warnings.warn("keeping thinking snapshots", RuntimeWarning) - snapshots.append(result[:10].detach().clone()) + snapshots.append(result[:100].detach().clone()) # Generate iteration after iteration @@ -1986,11 +1986,11 @@ class Greed(Task): # Set the lookahead_reward to UNKNOWN for the next iterations result[ :, u + self.world.index_lookahead_reward - ] = self.world.lookahead_reward2code(gree.REWARD_UNKNOWN) + ] = self.world.lookahead_reward2code(greed.REWARD_UNKNOWN) filename = os.path.join(result_dir, f"test_thinking_compute_{n_epoch:04d}.txt") with open(filename, "w") as f: - for n in range(10): + for n in range(snapshots[0].size(0)): for s in snapshots: lr, s, a, r = self.world.seq2episodes( s[n : n + 1],