From: François Fleuret Date: Mon, 27 Mar 2023 13:38:16 +0000 (+0200) Subject: Update X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=9c4325b877ede05e14699ddae211d1edc83c1515;p=beaver.git Update --- diff --git a/beaver.py b/beaver.py index 5407859..074e137 100755 --- a/beaver.py +++ b/beaver.py @@ -133,18 +133,6 @@ for n in vars(args): ###################################################################### -def generation_order(x, prompt_len=0): - if args.random_regression_order: - order = torch.rand(x.size(), device=x.device) - order[:, :prompt_len] = torch.arange(-prompt_len, 0, device=x.device) - order = order.sort(1).indices - else: - order = ( - torch.arange(x.size(1), device=x.device).unsqueeze(0).expand(x.size(0), -1) - ) - return order - - def reorder(x, order, reverse=False): # x is NxTxD1x...xDk, order is NxT' u = x.reshape(x.size()[:2] + (-1,)) order = order.unsqueeze(-1).expand(-1, -1, u.size(-1)) @@ -157,7 +145,14 @@ def reorder(x, order, reverse=False): # x is NxTxD1x...xDk, order is NxT' def shuffle(x, prompt_len): - order = generation_order(x, prompt_len) + if args.random_regression_order: + order = torch.rand(x.size(), device=x.device) + order[:, :prompt_len] = torch.arange(-prompt_len, 0, device=x.device) + order = order.sort(1).indices + else: + order = ( + torch.arange(x.size(1), device=x.device).unsqueeze(0).expand(x.size(0), -1) + ) return reorder(x, order), order