From 569f0204caea670d80bd664372543b5a3bb35997 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 27 Aug 2024 21:33:46 +0200 Subject: [PATCH] Update. --- main.py | 659 ++++++++------------------------------------------------ 1 file changed, 93 insertions(+), 566 deletions(-) diff --git a/main.py b/main.py index e4237d2..85213cb 100755 --- a/main.py +++ b/main.py @@ -354,7 +354,7 @@ def run_tests(model, quiz_machine, local_device=main_device): mask_loss = mask_loss.to(local_device) targets = input - output = model(mygpt.BracketedSequence(input)).x + output = model(input) loss_per_token = F.cross_entropy( output.transpose(1, 2), targets, reduction="none" ) @@ -407,26 +407,11 @@ def one_epoch(model, quiz_machine, local_device=main_device): model.optimizer.zero_grad() targets = input - - output = model(mygpt.BracketedSequence(input)).x - - loss_per_token = F.cross_entropy( - output.transpose(1, 2), targets, reduction="none" - ) - - # warnings.warn("entropy masking", RuntimeWarning) - # l = output.transpose(1, 2).log_softmax(dim=1) - # H = -(l * l.exp()).sum(dim=1) - # M = (H >= -math.log(0.99) / H.size(1)).long() - # print(H, M) - # loss_per_token = loss_per_token * M - - loss = (loss_per_token * mask_loss).mean() + model.loss + output = model(input) + loss = F.cross_entropy(output.transpose(1, 2), targets, reduction="none") + loss = (loss * mask_loss).mean() + model.loss acc_train_loss += loss.item() * input.size(0) - - loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1) - nb_train_samples += input.size(0) loss.backward() @@ -477,28 +462,6 @@ data_structures = [ ###################################################################### -def model_proba_solutions(model, quizzes): - l = ( - quiz_machine.models_logprobas( - model, quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) - ) - + quiz_machine.models_logprobas( - model, quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) - ) - + quiz_machine.models_logprobas( - model, quizzes, ("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 1, 0) - ) - + quiz_machine.models_logprobas( - model, quizzes, ("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 1, 0) - ) - ) - - return l.exp() - - -###################################################################### - - def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test): nb_validated, nb_to_validate = 0, (nb_for_train + nb_for_test) * len(models) nb_generated, nb_to_generate_per_iteration = 0, nb_to_validate @@ -853,6 +816,13 @@ class MyAttentionAE(nn.Module): return bs +###################################################################### + +nb_iterations = 25 +probs_iterations = 0.1 ** torch.linspace(0, 1, nb_iterations, device=main_device) +probs_iterations = probs_iterations[None, :] / probs_iterations.sum() + + def ae_batches( quiz_machine, nb, @@ -952,6 +922,27 @@ def ae_generate(model, input, mask_generate, noise_proba, nb_iterations_max=50): return input +###################################################################### + + +def model_ae_proba_solutions(model, input): + loss = 0 + + for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]: + mask_generate = quiz_machine.make_quiz_mask( + quizzes=input, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad + ) + targets, logits = targets_and_prediction( + probs_iterations, model, input, mask_generate + ) + loss_per_token = F.cross_entropy( + logits.transpose(1, 2), targets, reduction="none" + ) + loss += (loss_per_token * mask_generate).sum(dim=1) + + return 
(-loss).exp() + + def degrade_input(input, mask_generate, nb_iterations, noise_proba): noise = torch.randint( quiz_machine.problem.nb_colors, input.size(), device=input.device @@ -970,13 +961,31 @@ def degrade_input(input, mask_generate, nb_iterations, noise_proba): return result -###################################################################### +def targets_and_prediction(probs_iterations, model, input, mask_generate): + d = deterministic(mask_generate) + p = probs_iterations.expand(input.size(0), -1) + dist = torch.distributions.categorical.Categorical(probs=p) + N0 = dist.sample() + N1 = N0 + 1 + N0 = (1 - d) * N0 + N1 = (1 - d) * N1 + d * nb_iterations + + targets, input = degrade_input( + input, mask_generate, (0 * N1, N1), noise_proba=noise_proba + ) + input_with_mask = NTC_channel_cat(input, mask_generate) + logits = model(input_with_mask) -def run_ae_test(model, quiz_machine, n_epoch, local_device=main_device): + return targets, logits + + +def run_ae_test(model, other_models, quiz_machine, n_epoch, local_device=main_device): with torch.autograd.no_grad(): model.eval().to(local_device) + # Compute the loss + nb_test_samples, acc_test_loss = 0, 0.0 for input, mask_generate, mask_loss in ae_batches( @@ -986,18 +995,9 @@ def run_ae_test(model, quiz_machine, n_epoch, local_device=main_device): local_device, "test", ): - d = deterministic(mask_generate) - p = probs_iterations.expand(input.size(0), -1) - dist = torch.distributions.categorical.Categorical(probs=p) - N0 = dist.sample() - N1 = N0 + 1 - N0 = (1 - d) * N0 - N1 = (1 - d) * N1 + d * nb_iterations - targets, input = degrade_input( - input, mask_generate, (0 * N1, N1), noise_proba=noise_proba + targets, logits = targets_and_prediction( + probs_iterations, model, input, mask_generate ) - input_with_mask = NTC_channel_cat(input, mask_generate) - logits = model(input_with_mask) loss = NTC_masked_cross_entropy(logits, targets, mask_loss) acc_test_loss += loss.item() * input.size(0) nb_test_samples += input.size(0) @@ -1006,10 +1006,9 @@ def run_ae_test(model, quiz_machine, n_epoch, local_device=main_device): f"test_loss {n_epoch} model {model.id} {acc_test_loss/nb_test_samples}" ) - # ------------------------------------------- - # Test generation + # Compute the accuracy and save some images - nb_correct, nb_total, record = 0, 0, [] + nb_correct, nb_total, record_d, record_nd = 0, 0, [], [] for input, mask_generate, mask_loss in ae_batches( quiz_machine, @@ -1026,12 +1025,14 @@ def run_ae_test(model, quiz_machine, n_epoch, local_device=main_device): predicted_parts = mask_generate.reshape(mask_generate.size(0), 4, -1)[ :, :, 1 ] - solution_is_deterministic = predicted_parts.sum(dim=-1) == 1 - correct = (2 * correct - 1) * (solution_is_deterministic).long() + d = predicted_parts.sum(dim=-1) == 1 + correct = (2 * correct - 1) * d.long() nb_correct += (correct == 1).long().sum() nb_total += (correct != 0).long().sum() correct_parts = predicted_parts * correct[:, None] - record.append((result, predicted_parts, correct_parts)) + record_d.append((result[d], predicted_parts[d], correct_parts[d])) + nd = d == False + record_nd.append((result[nd], predicted_parts[nd], correct_parts[nd])) log_string( f"test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)" @@ -1039,27 +1040,34 @@ def run_ae_test(model, quiz_machine, n_epoch, local_device=main_device): model.test_accuracy = nb_correct / nb_total - filename = f"prediction_ae_{n_epoch:04d}.png" + for f, record in [("prediction", 
record_d), ("generative", record_nd)]: + filename = f"culture_{f}_{n_epoch:04d}_{model.id:02d}.png" + result, predicted_parts, correct_parts = ( + torch.cat([x[i] for x in record])[:128] for i in [0, 1, 2] + ) - result, predicted_parts, correct_parts = ( - torch.cat([x[i] for x in record]) for i in [0, 1, 2] - ) + l = [model_ae_proba_solutions(model, result) for model in other_models] + probas = torch.cat([x[:, None] for x in l], dim=1) + comments = [] - quiz_machine.problem.save_quizzes_as_image( - args.result_dir, - filename, - quizzes=result, - predicted_parts=predicted_parts, - correct_parts=correct_parts, - ) + for l in probas: + comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l])) - log_string(f"wrote {filename}") + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, + filename, + quizzes=result, + predicted_parts=predicted_parts, + correct_parts=correct_parts, + comments=comments, + ) + log_string(f"wrote {filename}") ###################################################################### -def one_ae_epoch(model, quiz_machine, n_epoch, local_device=main_device): +def one_ae_epoch(model, other_models, quiz_machine, n_epoch, local_device=main_device): model.train().to(local_device) nb_train_samples, acc_train_loss = 0, 0.0 @@ -1074,20 +1082,9 @@ def one_ae_epoch(model, quiz_machine, n_epoch, local_device=main_device): if nb_train_samples % args.batch_size == 0: model.optimizer.zero_grad() - d = deterministic(mask_generate) - p = probs_iterations.expand(input.size(0), -1) - dist = torch.distributions.categorical.Categorical(probs=p) - N0 = dist.sample() - N1 = N0 + 1 - N0 = (1 - d) * N0 - N1 = (1 - d) * N1 + d * nb_iterations - - targets, input = degrade_input( - input, mask_generate, (0 * N1, N1), noise_proba=noise_proba + targets, logits = targets_and_prediction( + probs_iterations, model, input, mask_generate ) - - input_with_mask = NTC_channel_cat(input, mask_generate) - logits = model(input_with_mask) loss = NTC_masked_cross_entropy(logits, targets, mask_loss) acc_train_loss += loss.item() * input.size(0) nb_train_samples += input.size(0) @@ -1101,17 +1098,13 @@ def one_ae_epoch(model, quiz_machine, n_epoch, local_device=main_device): f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}" ) - run_ae_test(model, quiz_machine, n_epoch, local_device=local_device) + run_ae_test(model, other_models, quiz_machine, n_epoch, local_device=local_device) ###################################################################### noise_proba = 0.05 -nb_iterations = 25 -probs_iterations = 0.1 ** torch.linspace(0, 1, nb_iterations, device=main_device) -probs_iterations = probs_iterations[None, :] / probs_iterations.sum() - models = [] for i in range(args.nb_models): @@ -1194,6 +1187,9 @@ for n_epoch in range(args.nb_epochs): # -------------------------------------------------------------------- + one_ae_epoch(models[0], models, quiz_machine, n_epoch, main_device) + exit(0) + ranked_models = sorted(models, key=lambda m: float(m.test_accuracy)) weakest_models = ranked_models[: len(gpus)] @@ -1202,18 +1198,20 @@ for n_epoch in range(args.nb_epochs): start_time = time.perf_counter() for gpu, model in zip(gpus, weakest_models): - log_string(f"training model {model.id}") + log_string(f"training model {model.id} (accuracy {model.test_accuracy})") t = threading.Thread( - target=one_ae_epoch, daemon=True, args=(model, quiz_machine, n_epoch, gpu) + target=one_ae_epoch, + daemon=True, + args=(model, models, quiz_machine, n_epoch, gpu), ) threads.append(t) 
t.start() - for t in threads: - t.join() + for t in threads: + t.join() # -------------------------------------------------------------------- @@ -1231,476 +1229,5 @@ for n_epoch in range(args.nb_epochs): }, os.path.join(args.result_dir, filename), ) - log_string(f"wrote {filename}") - -###################################################################### - - -def create_models(): - models = [] - - def compute_causal_attzero(t_q, t_k): - return t_q < t_k - - for k in range(args.nb_models): - log_string(f"creating model {k}") - - model = mygpt.MyGPT( - vocabulary_size=vocabulary_size, - dim_model=args.dim_model, - dim_keys=args.dim_keys, - dim_hidden=args.dim_hidden, - nb_heads=args.nb_heads, - nb_blocks=args.nb_blocks, - compute_attzero=compute_causal_attzero, - dropout=args.dropout, - ).to(main_device) - - class UpperBoundStd(nn.Module): - def __init__(self, std_max=1.0): - super().__init__() - self.std_max = std_max - - def forward(self, x): - std = x.std(dim=-1, keepdim=True) - y = (x - x.mean(dim=-1, keepdim=True)) / std.clamp(max=self.std_max) - return y - - if args.logit_std_max > 0: - model.readout.f = nn.Sequential( - model.readout.f, UpperBoundStd(std_max=args.logit_std_max) - ) - - model.id = k - model.train_c_quiz_bags = [] - model.test_c_quiz_bags = [] - - model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) - - model.test_accuracy = 0.0 - model.gen_test_accuracy = 0.0 - model.gen_state_dict = copy.deepcopy(model.state_dict()) - models.append(model) - - return models - - -common_c_quiz_bags = [] - -models = create_models() - -###################################################################### - -# We balance the computing time between training the models and -# generating c_quizzes - -total_time_generating_c_quizzes = 0 -total_time_training_models = 0 - -current_epoch = 0 - -if args.resume: - for model in models: - filename = f"gpt_{model.id:03d}.pth" - - try: - d = torch.load(os.path.join(args.result_dir, filename)) - model.load_state_dict(d["state_dict"]) - model.optimizer.load_state_dict(d["optimizer_state_dict"]) - model.test_accuracy = d["test_accuracy"] - model.gen_test_accuracy = d["gen_test_accuracy"] - model.gen_state_dict = d["gen_state_dict"] - model.train_c_quiz_bags = d["train_c_quiz_bags"] - model.test_c_quiz_bags = d["test_c_quiz_bags"] - log_string(f"successfully loaded {filename}") - except FileNotFoundError: - log_string(f"cannot find {filename}") - pass - - try: - filename = "state.pth" - state = torch.load(os.path.join(args.result_dir, filename)) - log_string(f"successfully loaded {filename}") - current_epoch = state["current_epoch"] - total_time_generating_c_quizzes = state["total_time_generating_c_quizzes"] - total_time_training_models = state["total_time_training_models"] - common_c_quiz_bags = state["common_c_quiz_bags"] - except FileNotFoundError: - log_string(f"cannot find {filename}") - pass - -###################################################################### - -nb_parameters = sum(p.numel() for p in models[0].parameters()) -log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)") - -###################################################################### - -if args.nb_new_c_quizzes_for_train is None: - args.nb_new_c_quizzes_for_train = args.nb_train_samples // 50 - -if args.nb_new_c_quizzes_for_test is None: - args.nb_new_c_quizzes_for_test = args.nb_test_samples // 50 - -log_string( - f"nb_new_c_quizzes_for_train {args.nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test 
{args.nb_new_c_quizzes_for_test}" -) - -###################################################################### - -if args.dirty_debug: - args.accuracy_to_make_c_quizzes = 0.0 - args.nb_models = 2 - args.nb_new_c_quizzes_for_train = 100 - args.nb_new_c_quizzes_for_test = 10 - -###################################################################### - - -class Folder(nn.Module): - def forward(self, x): - return x.mean(dim=1) - - -class Unfolder(nn.Module): - def __init__(self, T, dim): - super().__init__() - self.biases = nn.Parameter(torch.randn(T, dim)) - - def forward(self, x): - return x[:, None, :] + self.biases[None, :, :] - - -class Recorder(nn.Module): - def __init__(self, tape): - super().__init__() - self.tape = tape - - def forward(self, input): - self.tape.append(input) - return input - - -if args.test == "mlp": - model = models[0] - tape_input, tape_output = [], [] - L = len(model.trunk) - model.trunk.insert(L // 2 + 1, Recorder(tape_output)) - model.trunk.insert(L // 2, Recorder(tape_input)) - - mlp = nn.Sequential( - nn.Linear(args.dim_model, args.dim_model), - nn.ReLU(), - nn.Linear(args.dim_model, args.dim_model), - nn.ReLU(), - nn.Linear(args.dim_model, 8 * args.dim_model), - Folder(), - Unfolder(404, 8 * args.dim_model), - nn.Linear(8 * args.dim_model, args.dim_model), - nn.ReLU(), - nn.Linear(args.dim_model, args.dim_model), - nn.ReLU(), - nn.Linear(args.dim_model, args.dim_model), - ).to(main_device) - - mlp.optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-4) - - for n_epoch in range(args.nb_epochs): - train_input = quiz_machine.problem.generate_w_quizzes(args.nb_train_samples) - - tape_input.clear() - tape_output.clear() - - with torch.autograd.no_grad(): - model.to(main_device).eval() - for input in train_input.split(args.batch_size): - input = input.to(main_device) - output = model(mygpt.BracketedSequence(input)).x - - train_input = torch.cat([bs.x for bs in tape_input], dim=0) - train_targets = torch.cat([bs.x for bs in tape_output], dim=0) - - nb_train_samples, acc_train_loss = 0, 0.0 - src = zip( - train_input.split(args.batch_size), train_targets.split(args.batch_size) - ) - for input, targets in tqdm.tqdm( - src, - dynamic_ncols=True, - desc="train", - total=train_input.size(0) // args.batch_size, - ): - input = input.to(main_device) - output = mlp(input) - loss = F.mse_loss(output, targets) + output.abs().sum() - acc_train_loss += loss.item() * input.size(0) - nb_train_samples += input.size(0) - - mlp.optimizer.zero_grad() - loss.backward() - mlp.optimizer.step() - - log_string(f"mlp_loss {n_epoch} train {acc_train_loss/nb_train_samples}") - - exit(0) - -###################################################################### - - -def save_generated_c_quizzes(model, filename, nb=64): - while sum([x.size(0) for x in record]) < nb: - model = models[torch.randint(len(models), (1,)).item()] - c_quizzes = quiz_machine.generate_c_quizzes( - 64, - model_for_generation=model, - procedure=c_quizzes_procedure, - ) - - p = quiz_machine.models_logprobas( - model, - c_quizzes, - ("A", "f_A", "B", "f_B"), - (1, 1, 1, 1), - temperature=1, - ).exp() - - p_hot = quiz_machine.models_logprobas( - model, - c_quizzes, - ("A", "f_A", "B", "f_B"), - (1, 1, 1, 1), - temperature=args.temperature_hot, - ).exp() - - to_keep = p_hot * torch.rand(p_hot.size(), device=p_hot.device) >= p - record.append(c_quizzes[to_keep]) - - print("NB_KEPT", sum([x.size(0) for x in record])) - - quiz_machine.problem.save_quizzes_as_image( - args.result_dir, - filename, - quizzes=c_quizzes, - ) - - 
log_string(f"wrote {filename}") - - -###################################################################### - - -if args.test == "entropy": - model = models[0] - model.to(main_device) - - optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) - - log_string("starting testing entropy maximization") - - for n_epoch in range(100): - input = quiz_machine.generate_c_quizzes( - 128, - model_for_generation=model, - procedure=c_quizzes_procedure, - ) - - filename = f"test_{n_epoch:04d}.png" - - quiz_machine.problem.save_quizzes_as_image( - args.result_dir, - filename, - quizzes=input, - ) - - log_string(f"wrote {filename}") - - with torch.no_grad(): - for p in model.parameters(): - p += torch.randn(p.size(), device=p.device) * 1e-3 - - # nb_train_samples, acc_train_loss = 0, 0.0 - - # for k in range(1000 // args.batch_size): - # input = quiz_machine.generate_c_quizzes( - # args.batch_size, - # model_for_generation=model, - # procedure=[(("f_B", "f_A", "A", "B"), (1, 1, 1, 1), None)], - # ) - - # input = input.to(main_device) - # targets = input - # output = model(mygpt.BracketedSequence(input)).x - # loss = -F.cross_entropy(output.transpose(1, 2), targets) - # acc_train_loss += loss.item() * input.size(0) - # nb_train_samples += input.size(0) - - # optimizer.zero_grad() - # loss.backward() - # optimizer.step() - - # log_string( - # f"increase_entropy {n_epoch} entropy {acc_train_loss/nb_train_samples}" - # ) - - exit(0) - -###################################################################### - -for n_epoch in range(current_epoch, args.nb_epochs): - state = { - "current_epoch": n_epoch, - "total_time_generating_c_quizzes": total_time_generating_c_quizzes, - "total_time_training_models": total_time_training_models, - "common_c_quiz_bags": common_c_quiz_bags, - } - filename = "state.pth" - torch.save(state, os.path.join(args.result_dir, filename)) - log_string(f"wrote {filename}") - - log_string(f"--- epoch {n_epoch} ----------------------------------------") - - cta = " ".join([f"{float(m.test_accuracy):.04f}" for m in models]) - log_string(f"current_test_accuracies {cta}") - - cta = " ".join([f"{float(m.gen_test_accuracy):.04f}" for m in models]) - log_string(f"current_gen_test_accuracies {cta}") - - ################################################## - - for model in models: - if model.test_accuracy >= args.accuracy_to_make_c_quizzes: - log_string( - f"storing_gen model {model.id} accuracy {model.gen_test_accuracy} -> {model.test_accuracy}" - ) - model.gen_state_dict = copy.deepcopy(model.state_dict()) - model.gen_test_accuracy = model.test_accuracy - - # we restart - if total_time_generating_c_quizzes == 0: - total_time_training_models = 0 - - if min([m.gen_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes: - if args.reboot: - for model in models: - model.current_dict = copy.deepcopy(model.state_dict()) - model.load_state_dict(model.gen_state_dict) - - while True: - record_new_c_quizzes( - models, - quiz_machine, - args.nb_new_c_quizzes_for_train, - args.nb_new_c_quizzes_for_test, - ) - - nb_c_quizzes_per_model = [ - sum([x.size(0) for x in model.train_c_quiz_bags]) - for model in models - ] - - p = tuple( - f"{(x*100)/args.nb_train_samples:.02f}%" - for x in nb_c_quizzes_per_model - ) - - log_string(f"nb_c_quizzes_per_model {p}") - - m = max(nb_c_quizzes_per_model) - - if m * args.c_quiz_multiplier >= args.nb_train_samples: - break - - model = models[nb_c_quizzes_per_model.index(m)] - common_c_quiz_bags.append(torch.cat(model.train_c_quiz_bags, dim=0)) - nb_common_c_quizzes = 
sum([x.size(0) for x in common_c_quiz_bags]) - log_string( - f"rebooting the models with {nb_common_c_quizzes} culture quizzes" - ) - - models = create_models() - total_time_generating_c_quizzes = 0 - total_time_training_models = 0 - - elif total_time_training_models >= total_time_generating_c_quizzes: - for model in models: - model.current_dict = copy.deepcopy(model.state_dict()) - model.load_state_dict(model.gen_state_dict) - - start_time = time.perf_counter() - - record_new_c_quizzes( - models, - quiz_machine, - args.nb_new_c_quizzes_for_train, - args.nb_new_c_quizzes_for_test, - ) - - total_time_generating_c_quizzes += time.perf_counter() - start_time - - for model in models: - model.load_state_dict(model.current_dict) - - ################################################## - # Select, improve, and eval the worst model(s) - - if total_time_training_models <= total_time_generating_c_quizzes: - ranked_models = sorted( - models, - # This ugly recipe will pick the worst if there some below - # args.accuracy_to_make_c_quizzes or one at random if they - # are all above - key=lambda m: float( - m.test_accuracy - if m.test_accuracy < args.accuracy_to_make_c_quizzes - else args.accuracy_to_make_c_quizzes + torch.rand(1).item() - ), - ) - - weakest_models = ranked_models[: len(gpus)] - - threads = [] - - start_time = time.perf_counter() - - for gpu, model in zip(gpus, weakest_models): - log_string(f"training model {model.id}") - - t = threading.Thread( - target=one_epoch, daemon=True, args=(model, quiz_machine, gpu) - ) - - threads.append(t) - - t.start() - for t in threads: - t.join() - - total_time_training_models += time.perf_counter() - start_time - - # Save the models to disk - - for model in models: - filename = f"gpt_{model.id:03d}.pth" - torch.save( - { - "state_dict": model.state_dict(), - "optimizer_state_dict": model.optimizer.state_dict(), - "test_accuracy": model.test_accuracy, - "gen_test_accuracy": model.gen_test_accuracy, - "gen_state_dict": model.gen_state_dict, - "train_c_quiz_bags": model.train_c_quiz_bags, - "test_c_quiz_bags": model.test_c_quiz_bags, - }, - os.path.join(args.result_dir, filename), - ) log_string(f"wrote {filename}") - - ###################################################################### - - if args.log_command is not None: - s = args.log_command.split() - s.insert(1, args.result_dir) - subprocess.run(s) - -###################################################################### -- 2.39.5
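
For orientation, below is a minimal, self-contained sketch of the masked-denoising ("AE") training step that this patch moves to (the combination of degrade_input, targets_and_prediction and NTC_masked_cross_entropy). It is an illustration under stated assumptions, not the patched code: nb_colors and the toy `model` signature (a module mapping a (B, T, 2) integer tensor to (B, T, nb_colors) logits) are placeholders, the quiz_machine / NTC_* helpers are replaced by plain tensor ops, and the special case that pins the iteration count to nb_iterations for deterministic (single-quadrant) masks is omitted.

import torch
import torch.nn.functional as F

nb_colors = 10                      # placeholder for quiz_machine.problem.nb_colors
nb_iterations, noise_proba = 25, 0.05   # values used in the patch

# Decaying categorical over the number of degradation steps, as in the patch:
# the smallest step counts are roughly ten times more likely than the largest.
probs_iterations = 0.1 ** torch.linspace(0, 1, nb_iterations)
probs_iterations = probs_iterations / probs_iterations.sum()

def degrade(x, mask, n, noise_proba):
    # After n degradation steps, a maskable token has been replaced by a
    # uniformly random color with probability 1 - (1 - noise_proba) ** n.
    noise = torch.randint(nb_colors, x.size())
    proba_erased = 1 - (1 - noise_proba) ** n
    erased = mask * (torch.rand(x.size()) <= proba_erased[:, None]).long()
    return (1 - erased) * x + erased * noise

def ae_training_step(model, x, mask_generate, mask_loss):
    # Sample a per-sample number of degradation steps N >= 1.
    N = torch.distributions.Categorical(probs=probs_iterations).sample((x.size(0),)) + 1
    targets = x                                    # the clean sequence is the target
    noisy = degrade(x, mask_generate, N, noise_proba)
    # The generation mask is fed to the model as a second channel
    # (cf. NTC_channel_cat in the patch).
    inp = torch.cat([noisy[:, :, None], mask_generate[:, :, None]], dim=2)
    logits = model(inp)                            # (B, T, nb_colors)
    loss = F.cross_entropy(logits.transpose(1, 2), targets, reduction="none")
    return (loss * mask_loss).mean()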