From: François Fleuret Date: Thu, 19 Sep 2024 14:14:12 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=4084de83f541242ac815eba2c0883b84fd19141e;p=culture.git Update. --- diff --git a/main.py b/main.py index 7bdd09e..c08b04d 100755 --- a/main.py +++ b/main.py @@ -17,7 +17,7 @@ import threading, subprocess # import torch.multiprocessing as mp -# torch.set_float32_matmul_precision("high") +torch.set_float32_matmul_precision("high") # torch.set_default_dtype(torch.bfloat16) @@ -673,7 +673,9 @@ def evaluate_quizzes(quizzes, models, local_device): with_perturbations=True, local_device=local_device, ) + nb_mistakes = (result != quizzes).long().sum(dim=1) nb_correct += (nb_mistakes == 0).long() + result = predict_full( model=model, input=quizzes, @@ -851,7 +853,8 @@ def multithread_execution(fun, arguments): for args in arguments: # To get a different sequence between threads - log_string(f"dummy_rand {torch.rand(1)}") + # log_string(f"dummy_rand {torch.rand(1)}") + torch.rand(1) t = threading.Thread(target=threadable_fun, daemon=True, args=args) threads.append(t) t.start()