From: François Fleuret Date: Fri, 21 Jun 2024 06:51:00 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=0388ce599d0c60f1e3de4f796d60a3577081d22f;p=culture.git Update. --- diff --git a/main.py b/main.py index 5234d6f..d92c4a5 100755 --- a/main.py +++ b/main.py @@ -815,9 +815,10 @@ if nb_epochs_finished >= args.nb_epochs: time_pred_result = None -for n_epoch in range(nb_epochs_finished, args.nb_epochs): - learning_rate = learning_rate_schedule[n_epoch] +###################################################################### + +def one_epoch(model, task, learning_rate): log_string(f"learning_rate {learning_rate}") if args.optim == "sgd": @@ -850,6 +851,15 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs): if nb_train_samples % args.batch_size == 0: optimizer.step() + train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) + + log_string(f"train)perplexity {n_epoch} {train_perplexity}") + + +###################################################################### + + +def run_tests(model, task): with torch.autograd.no_grad(): model.eval() @@ -868,13 +878,6 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs): nb_test_samples += input.size(0) - train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) - test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples)) - - log_string( - f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}" - ) - task.produce_results( n_epoch=n_epoch, model=model, @@ -883,12 +886,25 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs): deterministic_synthesis=args.deterministic_synthesis, ) - time_current_result = datetime.datetime.now() - if time_pred_result is not None: - log_string( - f"next_result {time_current_result + (time_current_result - time_pred_result)}" - ) - time_pred_result = time_current_result + test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples)) + log_string(f"test)perplexity {n_epoch} {test_perplexity}") + + +###################################################################### + +for n_epoch in range(nb_epochs_finished, args.nb_epochs): + learning_rate = learning_rate_schedule[n_epoch] + + one_epoch(model, task, learning_rate) + + run_tests(model, task) + + time_current_result = datetime.datetime.now() + if time_pred_result is not None: + log_string( + f"next_result {time_current_result + (time_current_result - time_pred_result)}" + ) + time_pred_result = time_current_result checkpoint = { "nb_epochs_finished": n_epoch + 1,