From: François Fleuret Date: Fri, 13 Jun 2025 12:24:17 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=32945dec8f4c0b0406c58c7c9d21bdc0b9c6e73a;p=pytorch.git Update. --- diff --git a/tinyae.py b/tinyae.py index b4f3aba..0baa5a2 100755 --- a/tinyae.py +++ b/tinyae.py @@ -124,20 +124,22 @@ test_input.sub_(mu).div_(std) ###################################################################### -for epoch in range(args.nb_epochs): - acc_loss = 0 +for n_epoch in range(args.nb_epochs): + acc_train_loss = 0 for input in train_input.split(args.batch_size): output = model(input) - loss = 0.5 * (output - input).pow(2).sum() / input.size(0) + train_loss = F.mse_loss(output, input) optimizer.zero_grad() - loss.backward() + train_loss.backward() optimizer.step() - acc_loss += loss.item() + acc_train_loss += train_loss.detach().item() * input.size(0) - log_string("acc_loss {:d} {:f}.".format(epoch, acc_loss)) + train_loss = acc_train_loss / train_input.size(0) + + log_string(f"train_loss {n_epoch} {train_loss}") ######################################################################