From: François Fleuret Date: Sat, 17 Feb 2024 08:53:23 +0000 (+0100) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=26af4588b06ed463a4f9b9bcc4b527dd4c864d49;p=mygptrnn.git Update. --- diff --git a/fridge b/fridge index 143092c..a4d860b 100644 --- a/fridge +++ b/fridge @@ -335,3 +335,73 @@ class Calibrator: ) % k_star.size(0) k_star = k_star[l_barrel, t_barrel] + +###################################################################### + +2024 Feb 15 23:10:50 (from main.py) + + +def add_memex_v4(batches, memex_proba, marker_token): + for input in batches: + if torch.rand(1).item() < memex_proba: + t = ( + torch.arange(2 * input.size(1), device=input.device)[None, :] + .expand(input.size(0), -1) + .clone() + ) + + u = torch.rand(t.size(), device=t.device) + u[:, : input.size(1)] = 1.0 + memex_v3_proba_fragment = 1 / 20 + u = (u < memex_v3_proba_fragment).long() + v = u * torch.randint(input.size(1), u.size()) + u[:, input.size(1) + 1 :] = v[:, input.size(1) + 1 :] - u[ + :, : input.size(1) - 1 + ] * input.size(1) + u = u.cumsum().clamp(min=0) + + u0 = torch.randint(input.size(1), (input.size(0), 1), device=input.device) + caterpillar_length = args.nb_lines // args.caterpillar_height + u1 = ( + u0 + + torch.randint( + caterpillar_length, (input.size(0), 1), device=input.device + ) + + 1 + ) + + m0 = (t < u0).long() + m1 = (t >= u1).long() * (t < u1 + input.size(1)).long() + + t = t * m0 + ((-1) * (1 - m0) * (1 - m1)) + (t - u1) * m1 + m = (t < 0).long() + n = torch.arange(input.size(0), device=input.device)[:, None].expand( + -1, t.size(1) + ) + + new_input = input[n, t.clamp(min=0)] + new_input = (1 - m) * new_input + m * (marker_token) + + yield new_input + + yield input + + + +###################################################################### + +2024 Feb 16 17:07:48 (from main.py) + + # ||gn + lambda * gm|| = max(||gn||,||gm||) + # ||gn||^2 + lambda + lambda^2||gm||^2 = max(||gn||^2,||gm||^2) + # A = ||gm||^2 B = C = ||gn||^2 - max(||gn||^2, ||gm||^2) + +###################################################################### + +2024 Feb 16 17:07:51 (from main.py) + + # A,B,C = gmgm, gngm, gngn - max(gngn,gmgm) + # Delta = B*B - 4*A*C + # if(delta >= 0): + # l = ( -B - sqrt(Delta))/(2*A) + # ||gn||+l*rho*||gm|| = max(||gn||,rho*||gm||) diff --git a/main.py b/main.py index 6254807..2a90fd1 100755 --- a/main.py +++ b/main.py @@ -453,6 +453,7 @@ except FileExistsError: exit(1) loss_file = open(os.path.join(args.result_dir, "loss.dat"), "a") +lambda_file = open(os.path.join(args.result_dir, "lambda.dat"), "a") log_file = open(os.path.join(args.result_dir, args.log_filename), "a") @@ -530,7 +531,7 @@ def get_lr(n_epoch, it): ###################################################################### -def add_memex_v2(batches, memex_proba, marker_token): +def add_memex_v1(batches, memex_proba, marker_token): for input in batches: if torch.rand(1).item() < memex_proba: t = ( @@ -561,61 +562,45 @@ def add_memex_v2(batches, memex_proba, marker_token): new_input = input[n, t.clamp(min=0)] new_input = (1 - m) * new_input + m * (marker_token) - yield new_input + memex_mask = new_input.new_zeros(new_input.size()) + memex_mask[:, input.size(1) :] = 1.0 + + yield new_input, memex_mask yield input -def add_memex_v3(batches, memex_proba, marker_token): +def add_memex_v2(batches, memex_proba): for input in batches: if torch.rand(1).item() < memex_proba: - t = ( - torch.arange(2 * input.size(1), device=input.device)[None, :] - .expand(input.size(0), -1) - .clone() + t = torch.arange(input.size(1) // 4, device=input.device)[None, :].expand( + input.size(0), -1 ) - - u = torch.rand(t.size(), device=t.device) - u[:, : input.size(1)] = 1.0 - memex_v3_proba_fragment = 1 / 20 - u = (u < memex_v3_proba_fragment).long() - v = u * torch.randint(input.size(1), u.size()) - u[:, input.size(1) + 1 :] = v[:, input.size(1) + 1 :] - u[ - :, : input.size(1) - 1 - ] * input.size(1) - u = u.cumsum().clamp(min=0) - - u0 = torch.randint(input.size(1), (input.size(0), 1), device=input.device) - caterpillar_length = args.nb_lines // args.caterpillar_height - u1 = ( - u0 - + torch.randint( - caterpillar_length, (input.size(0), 1), device=input.device - ) - + 1 + t = t + torch.randint( + input.size(1) - t.size(1), (t.size(0), 1), device=t.device ) - - m0 = (t < u0).long() - m1 = (t >= u1).long() * (t < u1 + input.size(1)).long() - - t = t * m0 + ((-1) * (1 - m0) * (1 - m1)) + (t - u1) * m1 - m = (t < 0).long() n = torch.arange(input.size(0), device=input.device)[:, None].expand( -1, t.size(1) ) - new_input = input[n, t.clamp(min=0)] - new_input = (1 - m) * new_input + m * (marker_token) + flash = input[n, t] + new_input = torch.cat([input, flash], dim=1) - yield new_input + memex_mask = new_input.new_zeros(new_input.size()) + memex_mask[:, input.size(1) :] = 1.0 - yield input + yield new_input, memex_mask + + else: + yield input ###################################################################### assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"} +assert args.batch_size % args.physical_batch_size == 0 + def picoclvr_pruner_horizontal_green(p): return not ("green" in p and ("left" in p or "right" in p)) @@ -978,6 +963,21 @@ it = 0 n_batch = 0 + +def the_dot_products(value1, value2, params): + g1g1, g1g2, g2g2 = 0, 0, 0 + for p in params: + g1 = torch.autograd.grad(value1, p, retain_graph=True)[0] + g2 = torch.autograd.grad(value2, p, retain_graph=True)[0] + g1g1 += g1.pow(2).sum()[None] + g2g2 += g2.pow(2).sum()[None] + g1g2 += (g1 * g2).sum()[None] + return torch.cat([g1g1, g1g2, g2g2]) + + +movave_dot_products = 0 + + for n_epoch in range(nb_epochs_finished, nb_epochs): if args.optim == "sgd": optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate) @@ -1003,7 +1003,6 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): train_batches = add_memex_v2( batches=task.batches(split="train"), memex_proba=memex_proba, - marker_token=vocabulary_size - 1, ) def add_none(it): @@ -1015,11 +1014,37 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): for input in add_none(train_batches): if input is not None: + if type(input) is tuple: + input, memex_mask = input + memex_mask = memex_mask.to(device) + else: + memex_mask = None + model.reset_inner_loss() input = input.to(device) output = model(mygpt.BracketedSequence(input)).x - loss = F.cross_entropy(output.transpose(1, 2), input) + + if memex_mask is None: + loss = F.cross_entropy(output.transpose(1, 2), input) + else: + loss = F.cross_entropy(output.transpose(1, 2), input, reduction="none") + loss_regular = (loss * (1 - memex_mask)).mean() + loss_memex = (loss * memex_mask).mean() + + if not torch.is_tensor(movave_dot_products) or torch.rand(1) < 0.01: + dot_products = the_dot_products( + loss_regular, loss_memex, model.parameters() + ) + eps = 1e-3 + movave_dot_products = ( + 1 - eps + ) * movave_dot_products + eps * dot_products + + grgr, grgm, gmgm = movave_dot_products + l = (max(grgr, gmgm) - grgr) / gmgm + loss = loss_regular + l * loss_memex + inner_loss = model.get_inner_loss() acc_train_loss += loss.item() * input.size(0) @@ -1047,10 +1072,12 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): optimizer.step() grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt() loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n") + grgr, grgm, gmgm = movave_dot_products + l = (max(grgr, rho * gmgm) - grgr) / (rho * gmgm) + lambda_file.write(f"{n_epoch} {n_batch} {l} {grgr} {gmgm}\n") optimizer.zero_grad() nb_acc_samples = 0 - - n_batch += 1 + n_batch += 1 with torch.autograd.no_grad(): model.eval()