From: François Fleuret Date: Thu, 11 Jul 2024 15:52:40 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=a86dff174205c38d8e90d0d89ea399a6afb36359;p=culture.git Update. --- diff --git a/main.py b/main.py index 4cf4d59..a7338c7 100755 --- a/main.py +++ b/main.py @@ -18,6 +18,8 @@ import sky, grids, quiz_machine import threading +import torch.multiprocessing as mp + # world quizzes vs. culture quizzes ###################################################################### diff --git a/quiz_machine.py b/quiz_machine.py index ae14614..8ab5696 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -424,17 +424,23 @@ class QuizMachine: ) for model in models: - for input, l in zip( - c_quizzes.split(self.batch_size), logproba.split(self.batch_size) - ): - input = input.to(self.device) - ar_mask = self.make_ar_mask(input) - output = model(mygpt.BracketedSequence(input)).x - ce = ( - F.cross_entropy(output.transpose(1, 2), input, reduction="none") - * ar_mask - ) - l[:, model.id] = -ce.sum(dim=-1) + with torch.autograd.no_grad(): + t = model.training + model.eval() + + for input, l in zip( + c_quizzes.split(self.batch_size), logproba.split(self.batch_size) + ): + input = input.to(self.device) + ar_mask = self.make_ar_mask(input) + output = model(mygpt.BracketedSequence(input)).x + ce = ( + F.cross_entropy(output.transpose(1, 2), input, reduction="none") + * ar_mask + ) + l[:, model.id] = -ce.sum(dim=-1) + + model.train(t) return logproba.to("cpu")