From: François Fleuret Date: Fri, 21 Jun 2024 08:35:55 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=0ab695df8f6a2a0cc70a424e57943a0d5606903b;p=culture.git Update. --- diff --git a/main.py b/main.py index d92c4a5..18b19db 100755 --- a/main.py +++ b/main.py @@ -859,7 +859,7 @@ def one_epoch(model, task, learning_rate): ###################################################################### -def run_tests(model, task): +def run_tests(model, task, deterministic_synthesis): with torch.autograd.no_grad(): model.eval() @@ -883,7 +883,7 @@ def run_tests(model, task): model=model, result_dir=args.result_dir, logger=log_string, - deterministic_synthesis=args.deterministic_synthesis, + deterministic_synthesis=deterministic_synthesis, ) test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples)) @@ -897,7 +897,9 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs): one_epoch(model, task, learning_rate) - run_tests(model, task) + run_tests(model, task, deterministic_synthesis=True) + + # -------------------------------------------- time_current_result = datetime.datetime.now() if time_pred_result is not None: @@ -906,6 +908,8 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs): ) time_pred_result = time_current_result + # -------------------------------------------- + checkpoint = { "nb_epochs_finished": n_epoch + 1, "model_state": model.state_dict(), diff --git a/world.py b/world.py index ac201e7..97c7b1d 100755 --- a/world.py +++ b/world.py @@ -34,7 +34,7 @@ def generate( nb, height, width, - max_nb_obj=len(colors) - 2, + max_nb_obj=colors.size(0) - 2, nb_iterations=2, ): f_start = torch.zeros(nb, height, width, dtype=torch.int64) @@ -43,7 +43,7 @@ def generate( for n in range(nb): nb_fish = torch.randint(max_nb_obj, (1,)).item() + 1 - for c in range(nb_fish): + for c in torch.randperm(colors.size(0) - 2)[:nb_fish].sort().values: i, j = ( torch.randint(height - 2, (1,))[0] + 1, torch.randint(width - 2, (1,))[0] + 1, @@ -117,7 +117,7 @@ if __name__ == "__main__": height, width = 6, 8 start_time = time.perf_counter() - seq = generate(nb=64, height=height, width=width) + seq = generate(nb=64, height=height, width=width, max_nb_obj=3) delay = time.perf_counter() - start_time print(f"{seq.size(0)/delay:02f} samples/s")