From: François Fleuret Date: Fri, 13 Oct 2023 22:28:24 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=3e3bf1003aa0ecbf7d38b7b0c289fbe1cfa3101b;p=picoclvr.git Update. --- diff --git a/qmlp.py b/qmlp.py index 572cde1..b58598a 100755 --- a/qmlp.py +++ b/qmlp.py @@ -53,12 +53,14 @@ def generate_sets_and_params( batch_nb_mlps, 2 * nb_samples, dtype=torch.int64, device=device ) + nb_rec = 8 + nb_values = 2 # more increases the min-max gap + + rec_support = torch.empty(batch_nb_mlps, nb_rec, 4, device=device) + while (data_targets.float().mean(-1) - 0.5).abs().max() > 0.1: i = (data_targets.float().mean(-1) - 0.5).abs() > 0.1 nb = i.sum() - - nb_rec = 8 - nb_values = 2 # more increases the min-max gap support = torch.rand(nb, nb_rec, 2, nb_values, device=device) * 2 - 1 support = support.sort(-1).values support = support[:, :, :, torch.tensor([0, nb_values - 1])].view(nb, nb_rec, 4) @@ -75,7 +77,7 @@ def generate_sets_and_params( .values ) - data_input[i], data_targets[i] = x, y + data_input[i], data_targets[i], rec_support[i] = x, y, support train_input, train_targets = ( data_input[:, :nb_samples], @@ -85,15 +87,39 @@ def generate_sets_and_params( q_train_input = quantize(train_input, -1, 1) train_input = dequantize(q_train_input, -1, 1) - train_targets = train_targets q_test_input = quantize(test_input, -1, 1) test_input = dequantize(q_test_input, -1, 1) - test_targets = test_targets if save_as_examples: - for k in range(q_train_input.size(0)): - with open(f"example_{k:04d}.dat", "w") as f: + a = 2 * torch.arange(nb_quantization_levels).float() / (nb_quantization_levels - 1) - 1 + xf = torch.cat([a[:,None,None].expand(nb_quantization_levels, nb_quantization_levels,1), + a[None,:,None].expand(nb_quantization_levels, nb_quantization_levels,1)], 2) + xf = xf.reshape(1,-1,2).expand(min(q_train_input.size(0),10),-1,-1) + print(f"{xf.size()=} {x.size()=}") + yf = ( + ( + (xf[:, None, :, 0] >= rec_support[:xf.size(0), :, None, 0]).long() + * (xf[:, None, :, 0] <= rec_support[:xf.size(0), :, None, 1]).long() + * (xf[:, None, :, 1] >= rec_support[:xf.size(0), :, None, 2]).long() + * (xf[:, None, :, 1] <= rec_support[:xf.size(0), :, None, 3]).long() + ) + .max(dim=1) + .values + ) + + full_input, full_targets = xf,yf + + q_full_input = quantize(full_input, -1, 1) + full_input = dequantize(q_full_input, -1, 1) + + for k in range(q_full_input[:10].size(0)): + with open(f"example_full_{k:04d}.dat", "w") as f: + for u, c in zip(full_input[k], full_targets[k]): + f.write(f"{c} {u[0].item()} {u[1].item()}\n") + + for k in range(q_train_input[:10].size(0)): + with open(f"example_train_{k:04d}.dat", "w") as f: for u, c in zip(train_input[k], train_targets[k]): f.write(f"{c} {u[0].item()} {u[1].item()}\n") @@ -293,7 +319,7 @@ def generate_sequence_and_test_set( if __name__ == "__main__": import time - batch_nb_mlps, nb_samples = 128, 2500 + batch_nb_mlps, nb_samples = 128, 250 generate_sets_and_params( batch_nb_mlps=10, diff --git a/tasks.py b/tasks.py index 44599f7..b33dee2 100755 --- a/tasks.py +++ b/tasks.py @@ -1588,13 +1588,19 @@ class QMLP(Task): self.train_input = seq[:nb_train_samples] self.train_q_test_set = q_test_set[:nb_train_samples] + self.train_ref_test_errors = test_error[:nb_train_samples] self.test_input = seq[nb_train_samples:] self.test_q_test_set = q_test_set[nb_train_samples:] - self.ref_test_errors = test_error + self.test_ref_test_errors = test_error[nb_train_samples:] + + filename = os.path.join(result_dir, f"train_errors_ref.dat") + with open(filename, "w") as f: + for e in self.train_ref_test_errors: + f.write(f"{e}\n") filename = os.path.join(result_dir, f"test_errors_ref.dat") with open(filename, "w") as f: - for e in self.ref_test_errors: + for e in self.test_ref_test_errors: f.write(f"{e}\n") self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1