From: François Fleuret Date: Sat, 27 Jul 2024 03:31:22 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=4754fbcd96ba849dbc62b3b83e4cab0cae0c0621;p=culture.git Update. --- diff --git a/main.py b/main.py index ffdd16f..3787e9f 100755 --- a/main.py +++ b/main.py @@ -95,11 +95,11 @@ parser.add_argument("--proba_understands", type=float, default=0.9) parser.add_argument("--proba_not_understands", type=float, default=0.5) -parser.add_argument("--temperature_hot", type=float, default=2) +parser.add_argument("--temperature_hot", type=float, default=1.25) -parser.add_argument("--temperature_cold", type=float, default=0.75) +parser.add_argument("--temperature_cold", type=float, default=1.25) -parser.add_argument("--nb_rounds", type=int, default=1) +parser.add_argument("--nb_rounds", type=int, default=2) parser.add_argument("--c_quiz_validation_mode", type=str, default="predict") @@ -645,6 +645,94 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 ###################################################################### +#!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + + +def train_auto_encoder(): + model = mygpt.MyGPT( + vocabulary_size=vocabulary_size, + dim_model=args.dim_model, + dim_keys=args.dim_keys, + dim_hidden=args.dim_hidden, + nb_heads=args.nb_heads, + nb_blocks=args.nb_blocks, + causal=False, + dropout=args.dropout, + auto_encoder_dim=64, + ).to(main_device) + + test_w_quizzes = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples) + + optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) + + nb_train_samples, acc_train_loss = 0, 0.0 + + for n_epoch in range(args.nb_epochs): + train_w_quizzes = quiz_machine.problem.generate_w_quizzes(args.nb_train_samples) + for input in tqdm.tqdm( + train_w_quizzes.split(args.batch_size), + dynamic_ncols=True, + desc="training AE", + total=train_w_quizzes.size(0) // args.batch_size, + ): + model.train() + l = input.size(1) // 4 + input = input[:, -l:].to(main_device) + + if nb_train_samples % args.batch_size == 0: + optimizer.zero_grad() + + z_shape = model.encode(mygpt.BracketedSequence(input.to(main_device))) + output = model.decode(z_shape).x + loss = F.cross_entropy(output.transpose(1, 2), input) + acc_train_loss += loss.item() * input.size(0) + + nb_train_samples += input.size(0) + + loss.backward() + + if nb_train_samples % args.batch_size == 0: + optimizer.step() + + train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) + + log_string(f"train_perplexity {n_epoch} model ae {train_perplexity}") + + filename = f"auto_encoder.pth" + torch.save( + model.state_dict(), + os.path.join(args.result_dir, filename), + ) + log_string(f"wrote {filename}") + + with torch.autograd.no_grad(): + model.eval() + input = test_w_quizzes[:128, -l:] + z_shape = model.encode(mygpt.BracketedSequence(input.to(main_device))) + logits = model.decode(z_shape).x + + # dist = torch.distributions.categorical.Categorical(logits=logits) + # q = dist.sample() + + q = logits.argmax(dim=-1) + q = q.reshape(q.size(0) // 2, 2, -1) + input = input.reshape(input.size(0) // 2, 2, -1) + q = torch.cat([input.to("cpu"), q.to("cpu")], dim=1).reshape(q.size(0), -1) + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, + f"culture_ae_{n_epoch:04d}.png", + q, + ) + + return model + + +# ae = train_auto_encoder() + +# exit(0) + +###################################################################### + models = [] diff --git a/mygpt.py b/mygpt.py index 51c0862..9bec09e 100755 --- a/mygpt.py +++ b/mygpt.py @@ -114,6 +114,30 @@ class AddPositionalEncoding(nn.Module): ############################## +class EncoderHead(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.fc = nn.Linear(dim_in, dim_out) + + def forward(self, bs): + z = self.fc(bs.x).mean(dim=1) + return z, bs.x.shape + + +class DecoderBottom(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.fc = nn.Linear(dim_in, dim_out) + + def forward(self, z_shape): + z, shape = z_shape + y = self.fc(z)[:, None, :].expand(shape) + return BracketedSequence(y) + + +############################## + + class QKVAttention(nn.Module): def __init__( self, @@ -232,6 +256,7 @@ class MyGPT(nn.Module): nb_blocks, causal=False, dropout=0.0, + auto_encoder_dim=-1, len_max=1e5, ): super().__init__() @@ -273,6 +298,24 @@ class MyGPT(nn.Module): ), ] + if auto_encoder_dim > 0: + self.encoder = nn.Sequential( + *( + trunk_blocks[: nb_blocks // 2] + + [EncoderHead(dim_model, auto_encoder_dim)] + ) + ) + + self.decoder = nn.Sequential( + *( + [ + DecoderBottom(auto_encoder_dim, dim_model), + AddPositionalEncoding(len_max), + ] + + trunk_blocks[nb_blocks // 2 :] + ) + ) + self.trunk = nn.Sequential(*trunk_blocks) self.readout = CacheWrapper( @@ -288,13 +331,22 @@ class MyGPT(nn.Module): m.weight.fill_(1.0) def forward(self, bs): - # print(f"GENERATE {bs.first} {bs.first+bs.nb}") bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb) bs = self.embedding(bs) bs = self.trunk(bs) bs = self.readout(bs) return bs + def encode(self, bs): + bs = self.embedding(bs) + z = self.encoder(bs) + return z + + def decode(self, z_shape): + bs = self.decoder(z_shape) + bs = self.readout(bs) + return bs + def partial_forward(self, bs, start_layer=None, end_layer=None): if start_layer is None: # print(f"GENERATE {bs.first} {bs.first+bs.nb}") diff --git a/quiz_machine.py b/quiz_machine.py index a9319c7..7516aed 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -496,13 +496,17 @@ class QuizMachine: ###################################################################### - def generate_c_quizzes_simple( + def generate_c_quizzes_( self, nb, model_for_generation, temperature_hot=1.0, temperature_cold=1.0, ): + warnings.warn( + "**************************** simple quiz generation", RuntimeWarning + ) + c_quizzes = self.problem.create_empty_quizzes(nb, ("A", "f_A", "B", "f_B")) c_quizzes = c_quizzes.to(self.device)