Update.
author: François Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 13:26:27 +0000 (15:26 +0200)
committer: François Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 13:26:27 +0000 (15:26 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index e38cbc0..edc366a 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -373,12 +373,20 @@ def masked_cross_entropy(output, targets, masks):
 ######################################################################
 
 
+def add_hints_(imt_set):
+    input, masks, targets = imt_set
+    h = torch.rand(masks.size(), device=masks.device) - masks
+    t = h.sort(dim=1).values[:, args.nb_hints, None]
+    mask_hints = (h < t).long()
+    masks[...] = (1 - mask_hints) * masks
+    input[...] = (1 - mask_hints) * input + mask_hints * targets
+
+
 def add_hints(masks, fraction_with_hints):
     if fraction_with_hints > 0:
-        h = torch.rand(masks.size(), device=masks.device) * masks
-        mask_hints = h.sort(dim=1, descending=True).values < args.nb_hints
-        v = torch.rand(masks.size(0), device=masks.device)[:, None]
-        mask_hints = mask_hints * (v < fraction_with_hints).long()
+        h = torch.rand(masks.size(), device=masks.device) - masks
+        t = h.sort(dim=1).values[:, args.nb_hints, None]
+        mask_hints = (h < t).long()
         return (1 - mask_hints) * masks
     else:
         return masks
@@ -387,19 +395,18 @@ def add_hints(masks, fraction_with_hints):
 # IMT for input / masks / target
 
 
-def batch_prediction_imt(input, fraction_with_hints=0.0):
+def batch_for_prediction_imt(input):
     nb = input.size(0)
     masks = input.new_zeros(input.size())
     u = F.one_hot(torch.randint(4, (nb,), device=masks.device), num_classes=4)
     masks.view(nb, 4, -1)[...] = u[:, :, None]
-    masks = add_hints(masks, fraction_with_hints)
     targets = input
     input = (1 - masks) * targets
 
     return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
 
 
-def predict(model, imt_set, local_device=main_device, desc="predict"):
+def ae_predict(model, imt_set, local_device=main_device, desc="predict"):
     model.eval().to(local_device)
 
     record = []
@@ -428,20 +435,17 @@ def predict(model, imt_set, local_device=main_device, desc="predict"):
     return torch.cat(record)
 
 
-def predict_full(model, input, fraction_with_hints, local_device=main_device):
+def predict_full(model, input, local_device=main_device):
     input = input[:, None, :].expand(-1, 4, -1).reshape(-1, input.size(1))
     nb = input.size(0)
     masks = input.new_zeros(input.size())
     u = F.one_hot(torch.arange(nb, device=masks.device) % 4, num_classes=4)
     masks.view(nb, 4, -1)[...] = u[:, :, None]
-    masks_with_hints = add_hints(masks, fraction_with_hints)
     targets = input
-    input = (1 - masks_with_hints) * targets
-    imt_set = torch.cat(
-        [input[:, None], masks_with_hints[:, None], targets[:, None]], dim=1
-    )
+    input = (1 - masks) * targets
+    imt_set = torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
 
-    result = predict(model, imt_set, local_device=local_device, desc=None)
+    result = ae_predict(model, imt_set, local_device=local_device, desc=None)
     result = (result * masks).reshape(-1, 4, result.size(1)).sum(dim=1)
 
     return result
@@ -450,7 +454,7 @@ def predict_full(model, input, fraction_with_hints, local_device=main_device):
 ######################################################################
 
 
-def batch_generation_imt(input):
+def batch_for_generation_imt(input):
     nb = input.size(0)
     probs_iterations = 0.1 ** torch.linspace(
         0, 1, args.diffusion_nb_iterations, device=input.device
@@ -480,7 +484,7 @@ def prioritized_rand(low):
     return y
 
 
-def generate(model, nb, local_device=main_device, desc="generate"):
+def ae_generate(model, nb, local_device=main_device, desc="generate"):
     model.eval().to(local_device)
 
     all_input = quiz_machine.pure_noise(nb, local_device)
@@ -533,8 +537,8 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
 
     imt_set = torch.cat(
         [
-            batch_prediction_imt(q1, fraction_with_hints=0.5),
-            batch_generation_imt(q2),
+            batch_for_prediction_imt(q1),
+            batch_for_generation_imt(q2),
         ]
     )
 
@@ -597,13 +601,13 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
         args.result_dir, f"test_{n_epoch}_{model.id}_predict_full.png", quizzes=result
     )
 
-    # Save some images of the prediction results (one grid at random)
+    # Save some images of the prediction results
 
     quizzes = quiz_machine.quiz_set(
         args.nb_test_samples, c_quizzes, args.c_quiz_multiplier
     )
-    imt_set = batch_prediction_imt(quizzes.to(local_device))
-    result = predict(model, imt_set, local_device=local_device).to("cpu")
+    imt_set = batch_for_prediction_imt(quizzes.to(local_device))
+    result = ae_predict(model, imt_set, local_device=local_device).to("cpu")
     masks = imt_set[:, 1].to("cpu")
 
     correct = (quizzes == result).min(dim=1).values.long()
@@ -631,7 +635,7 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
 
     # Save some images of the ex nihilo generation of the four grids
 
-    result = generate(model, 150, local_device=local_device).to("cpu")
+    result = ae_generate(model, 150, local_device=local_device).to("cpu")
     quiz_machine.problem.save_quizzes_as_image(
         args.result_dir,
         f"culture_generation_{n_epoch}_{model.id}.png",
@@ -695,21 +699,21 @@ def evaluate_quizzes(quizzes, models, fraction_with_hints, local_device):
 ######################################################################
 
 
-def generate_c_quizzes(models, nb, local_device=main_device):
+def generate_c_quizzes(models, nb_to_generate, local_device=main_device):
     record = []
     nb_validated = 0
 
     start_time = time.perf_counter()
     last_log = -1
 
-    while nb_validated < nb:
+    while nb_validated < nb_to_generate:
         # Generate new quizzes
 
         model = models[torch.randint(len(models), (1,)).item()]
         model = copy.deepcopy(model).to(local_device).eval()
         generator_id = model.id
 
-        c_quizzes = generate(
+        c_quizzes = ae_generate(
             model=model,
             nb=args.physical_batch_size,
             local_device=local_device,
@@ -736,8 +740,8 @@ def generate_c_quizzes(models, nb, local_device=main_device):
         if last_log < 0 or duration > last_log + 10:
             last_log = duration
             if nb_validated > 0:
-                if nb_validated < nb:
-                    d = (nb - nb_validated) * duration / nb_validated
+                if nb_validated < nb_to_generate:
+                    d = (nb_to_generate - nb_validated) * duration / nb_validated
                     e = (
                         datetime.datetime.now() + datetime.timedelta(seconds=d)
                     ).strftime("%a %H:%M")
@@ -754,7 +758,7 @@ def generate_c_quizzes(models, nb, local_device=main_device):
 
     duration = time.perf_counter() - start_time
 
-    log_string(f"generate_c_quizz_speed {int(3600 * nb / duration)}/h")
+    log_string(f"generate_c_quizz_speed {int(3600 * nb_validated / duration)}/h")
 
     return torch.cat(record).to("cpu")
 
@@ -848,7 +852,7 @@ if args.quizzes is not None:
             mask_generate = quiz_machine.make_quiz_mask(
                 quizzes=quizzes, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
             )
-            result = generate(
+            result = ae_generate(
                 model,
                 (1 - mask_generate) * quizzes,
                 mask_generate,
@@ -874,8 +878,6 @@ if args.quizzes is not None:
 
 ######################################################################
 
-last_n_epoch_c_quizzes = 0
-
 c_quizzes = None
 
 time_c_quizzes = 0
@@ -963,7 +965,6 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         if c_quizzes is None:
             save_models(models, "naive")
 
-        last_n_epoch_c_quizzes = n_epoch
         nb_gpus = len(gpus)
         nb_c_quizzes_to_generate = (args.nb_c_quizzes + nb_gpus - 1) // nb_gpus
 
diff --git a/quiz_machine.py b/quiz_machine.py
index 594b5ca..e2f6d3b 100755 (executable)
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -203,6 +203,9 @@ class QuizMachine:
     def quiz_set(self, nb_samples, c_quizzes, c_quiz_multiplier=1):
         if c_quizzes is None:
             quizzes = self.problem.generate_w_quizzes(nb_samples)
+            quizzes = quizzes.view(quizzes.size(0), 4, -1)[:, :, 1:].reshape(
+                quizzes.size(0), -1
+            )
         else:
             if c_quiz_multiplier > 1:
                 n = min(c_quiz_multiplier, (nb_samples // 2) // c_quizzes.size(0))
@@ -222,15 +225,14 @@ class QuizMachine:
                 c_quizzes = c_quizzes[i]
 
             w_quizzes = self.problem.generate_w_quizzes(nb_samples - c_quizzes.size(0))
+            w_quizzes = w_quizzes.view(w_quizzes.size(0), 4, -1)[:, :, 1:].reshape(
+                w_quizzes.size(0), -1
+            )
             quizzes = torch.cat([w_quizzes, c_quizzes], dim=0)
 
         i = torch.randperm(quizzes.size(0), device=quizzes.device)
         quizzes = quizzes[i].contiguous()
 
-        quizzes = quizzes.view(quizzes.size(0), 4, -1)[:, :, 1:].reshape(
-            quizzes.size(0), -1
-        )
-
         return quizzes
 
     ######################################################################