From: François Fleuret Date: Fri, 13 Oct 2023 11:51:34 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=f44ab6863f93ae348e66ffbf52251d96d3b5453c;p=picoclvr.git Update. --- diff --git a/qmlp.y b/qmlp.py similarity index 54% rename from qmlp.y rename to qmlp.py index 7b97edb..e12f0e1 100755 --- a/qmlp.y +++ b/qmlp.py @@ -39,40 +39,24 @@ def dequantize(q, xmin, xmax): ###################################################################### -def create_model(): - hidden_dim = 32 - - model = nn.Sequential( - nn.Linear(2, hidden_dim), - nn.ReLU(), - nn.Linear(hidden_dim, hidden_dim), - nn.ReLU(), - nn.Linear(hidden_dim, 2), - ) - - return model - - -###################################################################### def generate_sets_and_params( - nb_mlps, + batch_nb_mlps, nb_samples, batch_size, nb_epochs, device=torch.device("cpu"), print_log=False, ): - data_input = torch.zeros(nb_mlps, 2 * nb_samples, 2, device=device) + data_input = torch.zeros(batch_nb_mlps, 2 * nb_samples, 2, device=device) data_targets = torch.zeros( - nb_mlps, 2 * nb_samples, dtype=torch.int64, device=device + batch_nb_mlps, 2 * nb_samples, dtype=torch.int64, device=device ) while (data_targets.float().mean(-1) - 0.5).abs().max() > 0.1: i = (data_targets.float().mean(-1) - 0.5).abs() > 0.1 nb = i.sum() - print(f"{nb=}") nb_rec = 2 support = torch.rand(nb, nb_rec, 2, 3, device=device) * 2 - 1 @@ -108,10 +92,10 @@ def generate_sets_and_params( test_targets = test_targets hidden_dim = 32 - w1 = torch.randn(nb_mlps, hidden_dim, 2, device=device) / math.sqrt(2) - b1 = torch.zeros(nb_mlps, hidden_dim, device=device) - w2 = torch.randn(nb_mlps, 2, hidden_dim, device=device) / math.sqrt(hidden_dim) - b2 = torch.zeros(nb_mlps, 2, device=device) + w1 = torch.randn(batch_nb_mlps, hidden_dim, 2, device=device) / math.sqrt(2) + b1 = torch.zeros(batch_nb_mlps, hidden_dim, device=device) + w2 = torch.randn(batch_nb_mlps, 2, hidden_dim, device=device) / math.sqrt(hidden_dim) + b2 = torch.zeros(batch_nb_mlps, 2, device=device) w1.requires_grad_() b1.requires_grad_() @@ -158,13 +142,13 @@ def generate_sets_and_params( # print(f"{k=} {acc_train_loss=} {train_error=}") q_params = torch.cat( - [quantize(p.view(nb_mlps, -1), -2, 2) for p in [w1, b1, w2, b2]], dim=1 + [quantize(p.view(batch_nb_mlps, -1), -2, 2) for p in [w1, b1, w2, b2]], dim=1 ) q_train_set = torch.cat([q_train_input, train_targets[:, :, None]], -1).reshape( - nb_mlps, -1 + batch_nb_mlps, -1 ) q_test_set = torch.cat([q_test_input, test_targets[:, :, None]], -1).reshape( - nb_mlps, -1 + batch_nb_mlps, -1 ) return q_train_set, q_test_set, q_params @@ -173,51 +157,59 @@ def generate_sets_and_params( ###################################################################### -def evaluate_q_params(q_params, q_set, batch_size=25, device=torch.device("cpu")): +def evaluate_q_params(q_params, q_set, batch_size=25, device=torch.device("cpu"), nb_mlps_per_batch=1024): + + errors = [] nb_mlps = q_params.size(0) - hidden_dim = 32 - w1 = torch.empty(nb_mlps, hidden_dim, 2, device=device) - b1 = torch.empty(nb_mlps, hidden_dim, device=device) - w2 = torch.empty(nb_mlps, 2, hidden_dim, device=device) - b2 = torch.empty(nb_mlps, 2, device=device) - - with torch.no_grad(): - k = 0 - for p in [w1, b1, w2, b2]: - print(f"{p.size()=}") - x = dequantize(q_params[:, k : k + p.numel() // nb_mlps], -2, 2).view( - p.size() - ) - p.copy_(x) - k += p.numel() // nb_mlps - q_set = q_set.view(nb_mlps, -1, 3) - data_input = dequantize(q_set[:, :, :2], -1, 1).to(device) - data_targets = q_set[:, :, 2].to(device) + for n in range(0,nb_mlps,nb_mlps_per_batch): + batch_nb_mlps = min(nb_mlps_per_batch,nb_mlps-n) + batch_q_params = q_params[n:n+batch_nb_mlps] + batch_q_set = q_set[n:n+batch_nb_mlps] + hidden_dim = 32 + w1 = torch.empty(batch_nb_mlps, hidden_dim, 2, device=device) + b1 = torch.empty(batch_nb_mlps, hidden_dim, device=device) + w2 = torch.empty(batch_nb_mlps, 2, hidden_dim, device=device) + b2 = torch.empty(batch_nb_mlps, 2, device=device) - print(f"{data_input.size()=} {data_targets.size()=}") + with torch.no_grad(): + k = 0 + for p in [w1, b1, w2, b2]: + print(f"{p.size()=}") + x = dequantize(batch_q_params[:, k : k + p.numel() // batch_nb_mlps], -2, 2).view( + p.size() + ) + p.copy_(x) + k += p.numel() // batch_nb_mlps - criterion = nn.CrossEntropyLoss() - criterion.to(device) + batch_q_set = batch_q_set.view(batch_nb_mlps, -1, 3) + data_input = dequantize(batch_q_set[:, :, :2], -1, 1).to(device) + data_targets = batch_q_set[:, :, 2].to(device) + + print(f"{data_input.size()=} {data_targets.size()=}") + + criterion = nn.CrossEntropyLoss() + criterion.to(device) + + acc_loss = 0.0 + nb_errors = 0 - acc_loss = 0.0 - nb_errors = 0 + for input, targets in zip( + data_input.split(batch_size, dim=1), data_targets.split(batch_size, dim=1) + ): + h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :] + h = F.relu(h) + output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :] + loss = F.cross_entropy(output.reshape(-1, output.size(-1)), targets.reshape(-1)) + acc_loss += loss.item() * input.size(0) + wta = output.argmax(-1) + nb_errors += (wta != targets).long().sum(-1) - for input, targets in zip( - data_input.split(batch_size, dim=1), data_targets.split(batch_size, dim=1) - ): - h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :] - h = F.relu(h) - output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :] - loss = F.cross_entropy(output.reshape(-1, output.size(-1)), targets.reshape(-1)) - acc_loss += loss.item() * input.size(0) - wta = output.argmax(-1) - nb_errors += (wta != targets).long().sum(-1) + errors.append(nb_errors / data_input.size(1)) + acc_loss = acc_loss / data_input.size(1) - error = nb_errors / data_input.size(1) - acc_loss = acc_loss / data_input.size(1) - return error + return torch.cat(errors) ###################################################################### @@ -229,40 +221,41 @@ def generate_sequence_and_test_set( batch_size, nb_epochs, device, + nb_mlps_per_batch=1024, ): - q_train_set, q_test_set, q_params = generate_sets_and_params( - nb_mlps, - nb_samples, - batch_size, - nb_epochs, - device=device, - ) - input = torch.cat( - [ - q_train_set, - q_train_set.new_full( - ( - q_train_set.size(0), - 1, + inputs, q_test_sets = [],[] + + for n in range(0,nb_mlps,nb_mlps_per_batch): + q_train_set, q_test_set, q_params = generate_sets_and_params( + batch_nb_mlps = min(nb_mlps_per_batch, nb_mlps - n), + nb_samples=nb_samples, + batch_size=batch_size, + nb_epochs=nb_epochs, + device=device, + ) + + inputs.append(torch.cat( + [ + q_train_set, + q_train_set.new_full( + ( + q_train_set.size(0), + 1, + ), + nb_quantization_levels, ), - nb_quantization_levels, - ), - q_params, - ], - dim=-1, - ) + q_params, + ], + dim=-1, + )) - print(f"SANITY #1 {q_train_set.size()=} {q_params.size()=} {input.size()=}") + q_test_sets.append(q_test_set) - ar_mask = ( - (torch.arange(input.size(0), device=input.device) > q_train_set.size(0) + 1) - .long() - .view(1, -1) - .reshape(nb_mlps, -1) - ) + input = torch.cat(inputs) + q_test_set = torch.cat(q_test_sets) - return input, ar_mask, q_test_set + return input, q_test_set ###################################################################### @@ -270,7 +263,7 @@ def generate_sequence_and_test_set( if __name__ == "__main__": import time - nb_mlps, nb_samples = 128, 200 + batch_nb_mlps, nb_samples = 128, 500 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -278,26 +271,22 @@ if __name__ == "__main__": data = [] - for n in range(2): - data.append( - generate_sequence_and_test_set( - nb_mlps=nb_mlps, - nb_samples=nb_samples, - device=device, - batch_size=25, - nb_epochs=250, - ) - ) + input, q_test_set = generate_sequence_and_test_set( + nb_mlps=batch_nb_mlps, + nb_samples=nb_samples, + device=device, + batch_size=25, + nb_epochs=250, + nb_mlps_per_batch=17 + ) end_time = time.perf_counter() - nb = sum([i.size(0) for i, _, _ in data]) - print(f"{nb / (end_time - start_time):.02f} samples per second") - - for input, ar_mask, q_test_set in data: - q_train_set = input[:, : nb_samples * 3] - q_params = input[:, nb_samples * 3 + 1 :] - print(f"SANITY #2 {q_train_set.size()=} {q_params.size()=} {input.size()=}") - error_train = evaluate_q_params(q_params, q_train_set) - print(f"train {error_train*100}%") - error_test = evaluate_q_params(q_params, q_test_set) - print(f"test {error_test*100}%") + print(f"{input.size(0) / (end_time - start_time):.02f} samples per second") + + q_train_set = input[:, : nb_samples * 3] + q_params = input[:, nb_samples * 3 + 1 :] + print(f"SANITY #2 {q_train_set.size()=} {q_params.size()=} {input.size()=}") + error_train = evaluate_q_params(q_params, q_train_set, nb_mlps_per_batch=17) + print(f"train {error_train*100}%") + error_test = evaluate_q_params(q_params, q_test_set, nb_mlps_per_batch=17) + print(f"test {error_test*100}%") diff --git a/tasks.py b/tasks.py index 183c3cf..ea10d7c 100755 --- a/tasks.py +++ b/tasks.py @@ -1550,3 +1550,105 @@ class Grid(Task): ###################################################################### + +import qmlp + + +class QMLP(Task): + + ###################### + + def __init__( + self, + nb_train_samples, + nb_test_samples, + batch_size, + logger=None, + device=torch.device("cpu"), + ): + super().__init__() + + self.device = device + self.batch_size = batch_size + + if logger is not None: + logger( + f"generating {nb_train_samples+nb_test_samples} samples (can take some time)" + ) + + self.train_descr = self.grid_factory.generate_samples( + nb_train_samples, lambda r: tqdm.tqdm(r) + ) + self.test_descr = self.grid_factory.generate_samples( + nb_test_samples, lambda r: tqdm.tqdm(r) + ) + + # Build the tokenizer + tokens = set() + for d in [self.train_descr, self.test_descr]: + for s in d: + for t in s.strip().split(" "): + tokens.add(t) + # make this set a sorted list to get the same tensors given + # the same descr + tokens = list(tokens) + tokens.sort() + tokens = ["#"] + tokens + self.token2id = dict([(t, n) for n, t in enumerate(tokens)]) + self.id2token = dict([(n, t) for n, t in enumerate(tokens)]) + self.t_nul = self.token2id["#"] + self.t_true = self.token2id["true"] + self.t_false = self.token2id["false"] + + # Tokenize the train and test sets + self.train_input = self.str2tensor(self.train_descr) + self.test_input = self.str2tensor(self.test_descr) + + def batches(self, split="train"): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + for batch in tqdm.tqdm( + input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}" + ): + yield self.trim(batch) + + def vocabulary_size(self): + return len(self.token2id) + + def produce_results( + self, n_epoch, model, result_dir, logger, deterministic_synthesis + ): + correct = self.test_input[:1000] + result = correct.clone() + ar_mask = torch.logical_or(result == self.t_true, result == self.t_false).long() + result *= 1 - ar_mask # paraaaaanoiaaaaaaa + + logger(f"----------------------------------------------------------") + + for e in self.tensor2str(result[:10]): + logger(f"test_before {e}") + + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) + + logger(f"----------------------------------------------------------") + + for e in self.tensor2str(result[:10]): + logger(f"test_after {e}") + + logger(f"----------------------------------------------------------") + + nb_total = ar_mask.sum().item() + nb_correct = ((correct == result).long() * ar_mask).sum().item() + + logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}") + logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}") + + +######################################################################