From: François Fleuret Date: Sun, 8 Sep 2024 07:44:29 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=895c492b8e65f9f73a22692131d96915132ccd17;p=culture.git Update. --- diff --git a/grids.py b/grids.py index 9e80f62..73e722e 100755 --- a/grids.py +++ b/grids.py @@ -709,20 +709,22 @@ class Grids(problem.Problem): nb_rec = 3 c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1 for X, f_X in [(A, f_A), (B, f_B)]: - r = self.rec_coo(nb_rec, prevent_overlap=True) + while True: + r = self.rec_coo(nb_rec, prevent_overlap=True) + if min([x[2] for x in r]) > self.height // 2 + 1: + break for n in range(nb_rec): i1, j1, i2, j2 = r[n] X[i1:i2, j1:j2] = c[n] f_X[i1:i2, j1:j2] = c[n] - X[: self.height // 2] = c[-1] + X[: self.height // 2] = 0 f_X[: self.height // 2] = f_X.flip([0])[: self.height // 2] if a == 1: + X[...] = X.flip((0,)) + f_X[...] = f_X.flip((0,)) + if b == 1: X[...] = X.clone().t() f_X[...] = f_X.clone().t() - if b == 1: - Z = X.clone() - X[...] = f_X - f_X[...] = Z # @torch.compile def task_translate(self, A, f_A, B, f_B): diff --git a/main.py b/main.py index 264b5c7..a4030ff 100755 --- a/main.py +++ b/main.py @@ -57,6 +57,10 @@ parser.add_argument("--nb_train_samples", type=int, default=25000) parser.add_argument("--nb_test_samples", type=int, default=1000) +parser.add_argument("--nb_train_alien_samples", type=int, default=0) + +parser.add_argument("--nb_test_alien_samples", type=int, default=0) + parser.add_argument("--nb_c_quizzes", type=int, default=2500) parser.add_argument("--nb_new_c_quizzes_for_train", type=int, default=None) @@ -304,7 +308,6 @@ alien_quiz_machine = quiz_machine.QuizMachine( logger=log_string, device=main_device, ) - # ------------------------------------------------------ ###################################################################### @@ -366,121 +369,13 @@ def optimizer_to(optim, device): subparam._grad.data = subparam._grad.data.to(device) -###################################################################### - -from mygpt import ( - WithResidual, - CacheWrapper, - VaswaniPositionalEncoding, - TrainablePositionalEncoding, - QKVAttention, - BracketedSequence, -) - - -class Thinker(nn.Module): - def __init__( - self, - vocabulary_size, - dim_model, - dim_keys, - dim_hidden, - nb_heads, - nb_blocks, - f_len, - dropout=0.0, - len_max=1e5, - ): - super().__init__() - - assert dim_model % nb_heads == 0 - - self.embedding = nn.Sequential( - CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)), - VaswaniPositionalEncoding(len_max), - ) - - def trunk(depth): - trunk_blocks = [] - - for b in range(nb_blocks): - trunk_blocks += [ - WithResidual( - CacheWrapper( - nn.LayerNorm((dim_model,)), - ), - QKVAttention( - dim_in=dim_model, - dim_qk=dim_keys, - dim_v=dim_model // nb_heads, - nb_heads=nb_heads, - attention_dropout=dropout, - ), - ), - WithResidual( - CacheWrapper( - nn.LayerNorm((dim_model,)), - nn.Linear(in_features=dim_model, out_features=dim_hidden), - nn.ReLU(), - nn.Linear(in_features=dim_hidden, out_features=dim_model), - nn.Dropout(dropout), - ), - ), - ] - - return nn.Sequential(*trunk_blocks) - - self.bottom_trunk = trunk(nb_blocks // 2) - - self.top_trunk = trunk(nb_blocks // 2) - - self.readout = CacheWrapper( - nn.Linear(in_features=dim_model, out_features=vocabulary_size) - ) - - self.fun_embedding = nn.Parameter(torch.randn(1, f_len, dim_model)) - - with torch.no_grad(): - for m in self.modules(): - if isinstance(m, nn.Embedding): - m.weight.normal_(mean=0, std=2e-2) - elif isinstance(m, nn.LayerNorm): - m.bias.zero_() - m.weight.fill_(1.0) - - def forward(self, bs): - for m in self.modules(): - m.loss = 0 - - L = bs.x.size(1) // 3 - - bs = self.embedding(bs) - A_fA = BracketedSequence(bs.x[:, : 2 * L]) - B = BracketedSequence(bs.x[:, -L:]) - - bs = BracketedSequence( - torch.cat([A_fA.x, self.fun_embedding.expand(bs.x.size(0), -1, -1)], dim=1) - ) - bs = self.bottom_trunk(bs) - bs = BracketedSequence(torch.cat([bs.x[:, -f_len:, :], B.x], dim=1)) - bs = self.top_trunk(bs) - bs = BracketedSequence(bs.x[:, f_len:, :]) - bs = self.readout(bs) - - for m in self.modules(): - if m is not self: - self.loss += m.loss - - return bs - - ###################################################################### from mygpt import ( WithResidual, CacheWrapper, - VaswaniPositionalEncoding, + CachedVaswaniPositionalEncoding, QKVAttention, BracketedSequence, ) @@ -548,7 +443,7 @@ class MyAttentionAE(nn.Module): ) # self.positional_encoding = TrainablePositionalEncoding(dim_model, len_max) - self.positional_encoding = VaswaniPositionalEncoding(len_max=1e5) + self.positional_encoding = CachedVaswaniPositionalEncoding(len_max=1e5) trunk_blocks = [] @@ -582,137 +477,6 @@ class MyAttentionAE(nn.Module): return bs -###################################################################### - -# f = phi(A, f(A)) + phi(B, f(B)) -# \hat{f(A)} = psi(A, f) -# \hat{A} = psi_inv(f(A), f) -# \hat{f(B)} = psi(B, f) -# \hat{B} = psi_inv(f(B), f) - - -def attention_layer(dim_model, dim_keys, nb_heads, dropout): - return WithResidual( - CacheWrapper( - nn.LayerNorm((dim_model,)), - ), - QKVAttention( - dim_in=dim_model, - dim_qk=dim_keys, - dim_v=dim_model // nb_heads, - nb_heads=nb_heads, - attention_dropout=dropout, - ), - ) - - -class FunctionalAE(nn.Module): - def __init__( - self, - vocabulary_size, - dim_model, - dim_keys, - dim_hidden, - nb_heads, - nb_blocks, - dropout=0.0, - len_max=1024, - ): - super().__init__() - - assert dim_model % nb_heads == 0 - - self.embedding = CacheWrapper( - nn.Sequential( - MultiEmbedding((vocabulary_size, 2), dim_model), nn.Dropout(dropout) - ), - ) - - # self.positional_encoding = TrainablePositionalEncoding(dim_model, len_max) - self.positional_encoding = VaswaniPositionalEncoding(len_max=1e5) - - def trunk(nb, bottom=True): - trunk_blocks = [VaswaniPositionalEncoding(len_max=1e5)] - - la = [ - QKVAttention( - dim_in=dim_model, - dim_qk=dim_keys, - dim_v=dim_model // nb_heads, - nb_heads=nb_heads, - attention_dropout=dropout, - ), - ] - - # if not bottom: - # trunk_blocks += la - - for b in range(nb): - trunk_blocks += [ - attention_block(dim_model, dim_keys, nb_heads, dropout), - ffw_block(dim_model, dim_hidden, nb_heads, dropout), - ] - - # if bottom: - # trunk_blocks += la - - return nn.Sequential(*trunk_blocks) - - self.phi = trunk(nb_blocks // 2, bottom=True) - nb_f_tokens = 200 - self.f_tokens = nn.Parameter( - torch.randn(1, nb_f_tokens, dim_model) / math.sqrt(nb_f_tokens) - ) - self.psi = trunk(nb_blocks // 2, bottom=False) - self.psi_inv = trunk(nb_blocks // 2, bottom=False) - self.internal_pe = VaswaniPositionalEncoding(len_max=1e5) - - self.readout = CacheWrapper( - nn.Linear(in_features=dim_model, out_features=vocabulary_size) - ) - - with torch.no_grad(): - for m in self.modules(): - if isinstance(m, nn.Embedding): - m.weight.normal_(mean=0, std=2e-2) - elif isinstance(m, nn.LayerNorm): - m.bias.zero_() - m.weight.fill_(1.0) - - def forward(self, bs): - def cat(*x): - return BracketedSequence(torch.cat(x, dim=1)) - - if torch.is_tensor(bs): - return self.forward(BracketedSequence(bs)).x - bs = self.embedding(bs) - bs = self.positional_encoding(bs) - - x_A, x_f_A, x_B, x_f_B = bs.x.chunk(4, dim=1) - - K = self.f_tokens.size(1) - N, L = x_A.size()[:2] - - ft = self.f_tokens.expand(N, -1, -1) - - theta_A = self.phi(cat(ft, x_A, x_f_A)).x[:, :K, :] - theta_B = self.phi(cat(ft, x_B, x_f_B)).x[:, :K, :] - - # if self.hook_theta is not None: - # self.hook_theta(theta_A, theta_B) - - hat_f_A = self.psi(cat(x_A, theta_B)).x[:, :L] - hat_f_B = self.psi(cat(x_B, theta_A)).x[:, :L] - - hat_A = self.psi_inv(cat(x_f_A, theta_B)).x[:, :L] - hat_B = self.psi_inv(cat(x_f_B, theta_A)).x[:, :L] - - bs = cat(hat_A, hat_f_A, hat_B, hat_f_B) - - bs = self.readout(bs) - return bs - - ###################################################################### # quad_order, quad_generate, quad_noise, quad_loss @@ -732,6 +496,8 @@ def ae_batches( data_structures, local_device, c_quizzes=None, + alien_quiz_machine=None, + nb_aliens=None, desc=None, batch_size=args.batch_size, ): @@ -1149,24 +915,25 @@ def run_ae_test( f"{prefix}test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)" ) - model.test_accuracy = nb_correct / nb_total - # Save some images - for f, record in [("prediction", record_d), ("generation", record_nd)]: - result, predicted_parts, correct_parts = bag_to_tensors(record) + if n_epoch < 50: + for f, record in [("prediction", record_d), ("generation", record_nd)]: + result, predicted_parts, correct_parts = bag_to_tensors(record) - filename = f"{prefix}culture_{f}_{n_epoch:04d}_{model.id:02d}.png" + filename = f"{prefix}culture_{f}_{n_epoch:04d}_{model.id:02d}.png" - quiz_machine.problem.save_quizzes_as_image( - args.result_dir, - filename, - quizzes=result[:128], - predicted_parts=predicted_parts[:128], - correct_parts=correct_parts[:128], - ) + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, + filename, + quizzes=result[:128], + predicted_parts=predicted_parts[:128], + correct_parts=correct_parts[:128], + ) - log_string(f"wrote {filename}") + log_string(f"wrote {filename}") + + return nb_correct / nb_total ###################################################################### @@ -1209,7 +976,19 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}" ) - run_ae_test(model, quiz_machine, n_epoch, c_quizzes=None, local_device=local_device) + model.test_accuracy = run_ae_test( + model, quiz_machine, n_epoch, c_quizzes=None, local_device=local_device + ) + + if args.nb_test_alien_samples > 0: + run_ae_test( + model, + alien_quiz_machine, + n_epoch, + c_quizzes=None, + local_device=local_device, + prefix="alien", + ) ###################################################################### @@ -1308,6 +1087,10 @@ def quiz_validation(models, c_quizzes, local_device): def generate_ae_c_quizzes(models, nb, local_device=main_device): # To be thread-safe we must make copies + + def copy_for_inference(model): + return copy.deepcopy(model).to(local_device).eval() + quad_order = ("A", "f_A", "B", "f_B") template = quiz_machine.problem.create_empty_quizzes( @@ -1318,9 +1101,6 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device): quizzes=template, quad_order=quad_order, quad_mask=(1, 1, 1, 1) ) - def copy_for_inference(model): - return copy.deepcopy(model).to(local_device).eval() - wanted_nb = nb nb_to_save = 256 nb_c_quizzes_per_model = torch.zeros(len(models), device=local_device) diff --git a/mygpt.py b/mygpt.py index a744224..5b56264 100755 --- a/mygpt.py +++ b/mygpt.py @@ -110,7 +110,7 @@ class CacheWrapper(nn.Module): ############################## -class WithResidual(nn.Module): +class CachedWithResidual(nn.Module): def __init__(self, *f): super().__init__() self.f = f[0] if len(f) == 1 else nn.Sequential(*f) @@ -122,7 +122,7 @@ class WithResidual(nn.Module): ############################## -class VaswaniPositionalEncoding(nn.Module): +class CachedVaswaniPositionalEncoding(nn.Module): def __init__(self, len_max): super().__init__() self.len_max = len_max @@ -358,13 +358,13 @@ class MyGPT(nn.Module): CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)), ) - self.positional_encoding = VaswaniPositionalEncoding(len_max) + self.positional_encoding = CachedVaswaniPositionalEncoding(len_max) trunk_blocks = [] for b in range(nb_blocks): trunk_blocks += [ - WithResidual( + CachedWithResidual( CacheWrapper( nn.LayerNorm((dim_model,)), NoiseInjector(identifier=("attention", b)), @@ -378,7 +378,7 @@ class MyGPT(nn.Module): attention_dropout=dropout, ), ), - WithResidual( + CachedWithResidual( CacheWrapper( nn.LayerNorm((dim_model,)), NoiseInjector(identifier=("ffw", b)),