From 240c59a123f0ed07c3faa0d02589c1348f41b365 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 17 Sep 2024 15:12:40 +0200 Subject: [PATCH] Update. --- main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index a353868..51e0fa2 100755 --- a/main.py +++ b/main.py @@ -433,12 +433,12 @@ def predict(model, imt_set, local_device=main_device): ): # some paranoia imt = imt.clone() - imt[:, 0] = imt[:, 0] * (1 - imt[:1]) + imt[:, 0] = imt[:, 0] * (1 - imt[:, 1]) with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): logits = model(imt[:, 0] * 2 + imt[:, 1]) dist = torch.distributions.categorical.Categorical(logits=logits) - result = (1 - masks) * imt[:, 0] + masks * dist.sample() + result = (1 - imt[:, 1]) * imt[:, 0] + imt[:, 1] * dist.sample() record.append(result) return torch.cat(record) -- 2.39.5