From 32945dec8f4c0b0406c58c7c9d21bdc0b9c6e73a Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 13 Jun 2025 14:24:17 +0200 Subject: [PATCH] Update. --- tinyae.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tinyae.py b/tinyae.py index b4f3aba..0baa5a2 100755 --- a/tinyae.py +++ b/tinyae.py @@ -124,20 +124,22 @@ test_input.sub_(mu).div_(std) ###################################################################### -for epoch in range(args.nb_epochs): - acc_loss = 0 +for n_epoch in range(args.nb_epochs): + acc_train_loss = 0 for input in train_input.split(args.batch_size): output = model(input) - loss = 0.5 * (output - input).pow(2).sum() / input.size(0) + train_loss = F.mse_loss(output, input) optimizer.zero_grad() - loss.backward() + train_loss.backward() optimizer.step() - acc_loss += loss.item() + acc_train_loss += train_loss.detach().item() * input.size(0) - log_string("acc_loss {:d} {:f}.".format(epoch, acc_loss)) + train_loss = acc_train_loss / train_input.size(0) + + log_string(f"train_loss {n_epoch} {train_loss}") ###################################################################### -- 2.39.5