From: François Fleuret Date: Tue, 17 Sep 2024 12:05:17 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=99d761505e3dc6a25ebb1c5939360af0f4509416;p=culture.git Update. --- diff --git a/main.py b/main.py index ce86a76..a353868 100755 --- a/main.py +++ b/main.py @@ -431,8 +431,10 @@ def predict(model, imt_set, local_device=main_device): desc="predict", total=imt_set.size(0) // args.physical_batch_size, ): - masks = imt[:, 1] - imt = imt * (1 - masks[:, None]) # paranoia + # some paranoia + imt = imt.clone() + imt[:, 0] = imt[:, 0] * (1 - imt[:1]) + with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): logits = model(imt[:, 0] * 2 + imt[:, 1]) dist = torch.distributions.categorical.Categorical(logits=logits) @@ -494,7 +496,7 @@ def generate(model, nb, local_device=main_device): changed = True for it in range(args.diffusion_nb_iterations): with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = model(input) + logits = model(input * 2 + masks) dist = torch.distributions.categorical.Categorical(logits=logits) output = dist.sample() @@ -507,7 +509,7 @@ def generate(model, nb, local_device=main_device): changed = changed & (update != input).max(dim=1).values input[changed] = update[changed] - return input + return all_input ###################################################################### @@ -563,7 +565,10 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True): log_string(f"{label}_loss {n_epoch} model {model.id} {acc_loss/nb_samples}") -def one_train_test_epoch(model, n_epoch, c_quizzes, local_device=main_device): +###################################################################### + + +def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device): # train one_epoch(model, n_epoch, c_quizzes, local_device=local_device, train=True) @@ -575,11 +580,13 @@ def one_train_test_epoch(model, n_epoch, c_quizzes, local_device=main_device): imt_set = IMT_batch_prediction(quizzes.to(local_device)) result = predict(model, imt_set, local_device=local_device).to("cpu") masks = imt_set[:, 1].to("cpu") + correct = (quizzes == result).min(dim=1).values.long() correct_parts = (2 * correct - 1)[:, None] * masks.reshape(masks.size(0), 4, -1)[ :, :, 1 ] predicted_parts = correct_parts.abs() + quiz_machine.problem.save_quizzes_as_image( args.result_dir, f"culture_prediction_{n_epoch}_{model.id}.png", @@ -1063,7 +1070,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): # None if c_quizzes is None else c_quizzes[agreements[:, model.id]], multithread_execution( - one_train_test_epoch, + one_complete_epoch, [(model, n_epoch, c_quizzes, gpu) for model, gpu in zip(weakest_models, gpus)], )