quizzes[i] = self.problem.inject_noise(
quizzes[i], self.prompt_noise, struct=struct, mask=mask_noise
)
- quiz_mask_loss[i] = self.make_ar_mask(
+ quiz_mask_loss[i] = self.make_quiz_mask(
quizzes=quizzes[i], struct=struct, mask=mask_loss
)
######################################################################
- def make_ar_mask(self, quizzes, struct, mask):
+ def make_quiz_mask(self, quizzes, struct, mask):
assert struct in [s for s, _, _, _ in self.train_structures]
- return self.problem.make_ar_mask(quizzes, struct=struct, mask=mask)
+ return self.problem.make_quiz_mask(quizzes, struct=struct, mask=mask)
######################################################################
def predict(self, model, quizzes, struct, mask):
- ar_mask = self.make_ar_mask(quizzes=quizzes, struct=struct, mask=mask)
+ ar_mask = self.make_quiz_mask(quizzes=quizzes, struct=struct, mask=mask)
result = quizzes * (1 - ar_mask)
seq_logproba = torch.empty(quizzes.size(0), device=self.device)
seq_logproba.split(self.batch_size),
):
input = input.to(device)
- quiz_mask_loss = self.make_ar_mask(
+ quiz_mask_loss = self.make_quiz_mask(
input, struct=struct, mask=mask_loss
)
output = model(mygpt.BracketedSequence(input)).x
self.autoregression(
model=model_for_generation,
input=c_quizzes,
- ar_mask=self.make_ar_mask(c_quizzes, s, m),
+ ar_mask=self.make_quiz_mask(c_quizzes, s, m),
seq_logproba=seq_logproba,
)