From: François Fleuret Date: Sat, 27 Jul 2024 03:38:31 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=0d3bda1a286803536e7fb3dce4e1ff7c7a9de942;p=culture.git Update. --- diff --git a/main.py b/main.py index 3787e9f..848ac9c 100755 --- a/main.py +++ b/main.py @@ -645,7 +645,6 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 ###################################################################### -#!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! def train_auto_encoder(): @@ -658,9 +657,10 @@ def train_auto_encoder(): nb_blocks=args.nb_blocks, causal=False, dropout=args.dropout, - auto_encoder_dim=64, ).to(main_device) + model.make_auto_encoder(auto_encoder_dim=64) + test_w_quizzes = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples) optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) diff --git a/mygpt.py b/mygpt.py index 9bec09e..b38cc99 100755 --- a/mygpt.py +++ b/mygpt.py @@ -256,7 +256,6 @@ class MyGPT(nn.Module): nb_blocks, causal=False, dropout=0.0, - auto_encoder_dim=-1, len_max=1e5, ): super().__init__() @@ -298,24 +297,6 @@ class MyGPT(nn.Module): ), ] - if auto_encoder_dim > 0: - self.encoder = nn.Sequential( - *( - trunk_blocks[: nb_blocks // 2] - + [EncoderHead(dim_model, auto_encoder_dim)] - ) - ) - - self.decoder = nn.Sequential( - *( - [ - DecoderBottom(auto_encoder_dim, dim_model), - AddPositionalEncoding(len_max), - ] - + trunk_blocks[nb_blocks // 2 :] - ) - ) - self.trunk = nn.Sequential(*trunk_blocks) self.readout = CacheWrapper( @@ -337,6 +318,24 @@ class MyGPT(nn.Module): bs = self.readout(bs) return bs + def make_auto_encoder(self, auto_encoder_dim): + self.encoder = nn.Sequential( + *( + trunk_blocks[: nb_blocks // 2] + + [EncoderHead(dim_model, auto_encoder_dim)] + ) + ) + + self.decoder = nn.Sequential( + *( + [ + DecoderBottom(auto_encoder_dim, dim_model), + AddPositionalEncoding(len_max), + ] + + trunk_blocks[nb_blocks // 2 :] + ) + ) + def encode(self, bs): bs = self.embedding(bs) z = self.encoder(bs)