parser.add_argument("--dirty_debug", action="store_true", default=False)
-parser.add_argument("--autoencoder_dim", type=int, default=-1)
+parser.add_argument("--test_generator", action="store_true", default=False)
######################################################################
######################################################################
-# DIRTY TEST
+if args.test_generator:
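+    # extend the vocabulary with three extra tokens, presumably used as
+    # prologue markers when assembling the generator's input sequences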
+    token_prolog_0 = vocabulary_size + 0
+    token_prolog_1 = vocabulary_size + 1
+    token_prolog_2 = vocabulary_size + 2
+    generator_vocabulary_size = vocabulary_size + 3
-# train_complexifier(models[0], models[1], models[2])
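+    # a GPT with the same hyper-parameters as the solver models, but with the
+    # enlarged vocabulary that includes the prologue tokens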
+    generator = mygpt.MyGPT(
+        vocabulary_size=generator_vocabulary_size,
+        dim_model=args.dim_model,
+        dim_keys=args.dim_keys,
+        dim_hidden=args.dim_hidden,
+        nb_heads=args.nb_heads,
+        nb_blocks=args.nb_blocks,
+        causal=True,
+        dropout=args.dropout,
+    ).to(main_device)
-# exit(0)
+    generator.main_test_accuracy = 0.0
-######################################################################
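+    # train the generator; w_quizzes=True presumably makes one_generator_epoch
+    # use the world quizzes as training material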
+    for n_epoch in range(args.nb_epochs):
+        one_generator_epoch(
+            generator,
+            quiz_machine=quiz_machine,
+            models=models,
+            w_quizzes=True,
+            local_device=main_device,
+        )
-token_prolog_0 = vocabulary_size + 0
-token_prolog_1 = vocabulary_size + 1
-token_prolog_2 = vocabulary_size + 2
-generator_vocabulary_size = vocabulary_size + 3
-
-generator = mygpt.MyGPT(
- vocabulary_size=generator_vocabulary_size,
- dim_model=args.dim_model,
- dim_keys=args.dim_keys,
- dim_hidden=args.dim_hidden,
- nb_heads=args.nb_heads,
- nb_blocks=args.nb_blocks,
- causal=True,
- dropout=args.dropout,
-).to(main_device)
-
-generator.main_test_accuracy = 0.0
-
-for n_epoch in range(25):
- one_generator_epoch(
- generator,
- quiz_machine=quiz_machine,
- models=models,
- w_quizzes=True,
- local_device=main_device,
- )
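+        # checkpoint the generator's weights and test accuracy after every epoch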
+ filename = f"generator.pth"
+ torch.save(
+ (generator.state_dict(), generator.main_test_accuracy),
+ os.path.join(args.result_dir, filename),
+ )
+ log_string(f"wrote {filename}")
- c_quizzes = generate_c_quizz_with_generator(
- generator, quiz_machine, args.batch_size
- )
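+        # sample a batch of c_quizzes from the freshly trained generator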
+        c_quizzes = generate_c_quizz_with_generator(
+            generator, quiz_machine, args.batch_size
+        )
- seq_logproba = quiz_machine.models_logprobas(
- models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
- ) + quiz_machine.models_logprobas(
- models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1)
- )
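+        # score the c_quizzes with the solver models, summing log-probabilities
+        # over both quiz orderings; the (0, 0, 0, 1) mask presumably restricts
+        # the score to the final part of each quiz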
+        seq_logproba = quiz_machine.models_logprobas(
+            models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
+        ) + quiz_machine.models_logprobas(
+            models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1)
+        )
- print(seq_logproba.exp())
+        print(seq_logproba.exp())
+        comments = []
-one_generator_epoch(
- generator,
- quiz_machine=quiz_machine,
- models=models,
- w_quizzes=False,
- local_device=main_device,
-)
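+        # one caption per quiz, listing each model's probability for it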
+        for lp in seq_logproba:
+            comments.append("proba " + " ".join([f"{x.exp().item():.2f}" for x in lp]))
+
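+        # render the batch to an image, with the probability captions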
+ filename = f"generator_batch_{n_epoch:04d}.png"
+ quiz_machine.problem.save_quizzes_as_image(
+ args.result_dir, filename, c_quizzes, comments=comments
+ )
+ log_string(f"wrote {filename}")
-exit(0)
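+    # a final pass without the world quizzes, presumably evaluation only, then stop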
+    one_generator_epoch(
+        generator,
+        quiz_machine=quiz_machine,
+        models=models,
+        w_quizzes=False,
+        local_device=main_device,
+    )
+    exit(0)
######################################################################