From: François Fleuret Date: Tue, 23 Jul 2024 07:24:07 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=refs%2Fheads%2Finv;p=culture.git Update. --- diff --git a/main.py b/main.py index 4a0c1f5..61820dd 100755 --- a/main.py +++ b/main.py @@ -451,7 +451,8 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 nb_validated_per_model = torch.zeros(len(models), dtype=torch.int64) while nb_validated_per_model.sum() < nb_to_validate: - # We balance the number of quizzes per model + # We use the model that has generated the fewest quizzes to + # balance the number of quizzes per model overall model_for_generation = sorted( models, key=lambda m: nb_validated_per_model[m.id] @@ -468,29 +469,39 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 temperature_cold=args.temperature_cold, ) - # We discard the trivial ones + # We discard the trivial ones, according to a criterion + # specific to the world quizzes (e.g. B=f(B)) c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)] # We go through nb_rounds rounds and keep only quizzes on - # which models respond always the same through rounds and one - # which N-1 succeed and one fails + # which + # + # (1) models respond always the same through rounds, and + # + # (2) at least one and up to max_fail_to_validate model(s) + # fail(s) - ms = 0 # "model scores" + # This is nb_quizzes x nb_models + number_correct_responses = 0 for r in range(args.nb_rounds): - ms += quiz_machine.models_successes(models, c_quizzes) - nb_sure_and_correct = (ms == r + 1).long().sum(dim=1) - nb_sure_and_fail = (ms == 0).long().sum(dim=1) + number_correct_responses += quiz_machine.models_successes(models, c_quizzes) + + nb_sure_correct = (number_correct_responses == r + 1).long().sum(dim=1) + nb_sure_fail = (number_correct_responses == 0).long().sum(dim=1) + to_keep = ( - (nb_sure_and_correct + nb_sure_and_fail == ms.size(1)) - & (nb_sure_and_fail >= 1) - & (nb_sure_and_fail <= args.max_fail_to_validate) + (nb_sure_correct + nb_sure_fail == number_correct_responses.size(1)) + & (nb_sure_fail >= 1) + & (nb_sure_fail <= args.max_fail_to_validate) ) c_quizzes = c_quizzes[to_keep] - ms = ms[to_keep] - print(f"Round {r} remains {c_quizzes.size(0)}") + number_correct_responses = number_correct_responses[to_keep] + + log_string(f"round {r} remains {c_quizzes.size(0)}") + if c_quizzes.size(0) == 0: break @@ -552,6 +563,7 @@ models = [] for k in range(args.nb_gpts): log_string(f"creating model {k} and its w_quizzes") + model = mygpt.MyGPT( vocabulary_size=vocabulary_size, dim_model=args.dim_model, @@ -568,15 +580,8 @@ for k in range(args.nb_gpts): quiz_machine.create_w_quizzes( model=model, - nb=args.nb_train_samples, - for_train=True, - p2a_only=args.p2a_only, - ) - - quiz_machine.create_w_quizzes( - model=model, - nb=args.nb_test_samples, - for_train=False, + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, p2a_only=args.p2a_only, ) @@ -733,11 +738,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): # Renew the training samples for model in weakest_models: - quiz_machine.renew_w_quizzes( - model=model, - for_train=True, - p2a_only=args.p2a_only, - ) + quiz_machine.renew_train_w_quizzes(model=model, p2a_only=args.p2a_only) if args.log_command is not None: s = args.log_command.split() diff --git a/quiz_machine.py b/quiz_machine.py index b1f6be1..d6c686e 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -357,45 +357,47 @@ class QuizMachine: ###################################################################### - def create_w_quizzes(self, model, nb, for_train=True, p2a_only=False): - input = self.generate_token_sequences(nb) + def create_w_quizzes( + self, model, nb_train_samples, nb_test_samples, p2a_only=False + ): + model.train_w_quizzes = self.generate_token_sequences(nb_train_samples) + model.test_w_quizzes = self.generate_token_sequences(nb_test_samples) if not p2a_only: - self.p_a_flip_half_in_place(input) - - if for_train: - model.train_w_quizzes = input - else: - model.test_w_quizzes = input + self.p_a_flip_half_in_place(model.train_w_quizzes) + self.p_a_flip_half_in_place(model.test_w_quizzes) ###################################################################### - def renew_w_quizzes(self, model, for_train=True, p2a_only=False): - input = model.train_w_quizzes if for_train else model.test_w_quizzes - + def renew_train_w_quizzes(self, model, p2a_only=False): if for_train and hasattr(model, "hard_w_quizzes"): self.logger( f"re-using {model.hard_w_quizzes.size(0)} hard world quizzes from model {model.id}" ) - if model.hard_w_quizzes.size(0) >= input.size(0): - input[...] = model.hard_w_quizzes[ - torch.randperm(hard_w_quizzes.size(0))[input.size(0)] + + if model.hard_w_quizzes.size(0) >= model.train_w_quizzes.size(0): + model.train_w_quizzes[...] = model.hard_w_quizzes[ + torch.randperm(hard_w_quizzes.size(0))[ + model.train_w_quizzes.size(0) + ] ] else: - input[...] = torch.cat( + model.train_w_quizzes[...] = torch.cat( [ model.hard_w_quizzes, self.generate_token_sequences( - input.size(0) - model.hard_w_quizzes.size(0) + model.train_w_quizzes.size(0) - model.hard_w_quizzes.size(0) ), ], dim=0, ) else: - input[...] = self.generate_token_sequences(input.size(0)) + model.train_w_quizzes[...] = self.generate_token_sequences( + model.train_w_quizzes.size(0) + ) if not p2a_only: - self.p_a_flip_half_in_place(input) + self.p_a_flip_half_in_place(model.train_w_quizzes) ######################################################################