From: François Fleuret
Date: Sun, 8 Sep 2024 10:31:09 +0000 (+0200)
Subject: Update.
X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=d12981e75482b80a73f809259ea62150754dd53f;p=culture.git

Update.
---

diff --git a/attae.py b/attae.py
index 7bd4a44..e9e4bff 100755
--- a/attae.py
+++ b/attae.py
@@ -102,7 +102,7 @@ class AttentionAE(nn.Module):
         assert dim_model % nb_heads == 0
 
         self.embedding = nn.Sequential(
-            nn.Embedding(vocabulary_size, dim_model),
+            nn.Embedding(2 * vocabulary_size, dim_model),
             nn.Dropout(dropout),
         )
 
@@ -143,7 +143,8 @@ class AttentionAE(nn.Module):
                     m.bias.zero_()
                     m.weight.fill_(1.0)
 
-    def forward(self, x, mask=None):
+    def forward(self, x):
+        x = 2 * x[:, :, 0] + x[:, :, 1]
         x = self.embedding(x)
         x = self.positional_encoding(x)
         x = self.trunk(x)
diff --git a/main.py b/main.py
index d90a3df..9285337 100755
--- a/main.py
+++ b/main.py
@@ -999,8 +999,8 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
 models = []
 
 for i in range(args.nb_models):
-    model = MyAttentionAE(
-        # model = attae.AttentionAE(
+    # model = MyAttentionAE(
+    model = attae.AttentionAE(
         vocabulary_size=vocabulary_size,
         dim_model=args.dim_model,
         dim_keys=args.dim_keys,
@@ -1338,6 +1338,9 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     else:
         log_string(f"nb_c_quizzes {c_quizzes.size(0)}")
 
+    # one_ae_epoch(model, quiz_machine, n_epoch, None)
+    # exit(0)
+
     # --------------------------------------------------------------------
 
     ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))