Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 17 Sep 2024 13:12:40 +0000 (15:12 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 17 Sep 2024 13:12:40 +0000 (15:12 +0200)
main.py

diff --git a/main.py b/main.py
index a353868..51e0fa2 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -433,12 +433,12 @@ def predict(model, imt_set, local_device=main_device):
     ):
         # some paranoia
         imt = imt.clone()
-        imt[:, 0] = imt[:, 0] * (1 - imt[:1])
+        imt[:, 0] = imt[:, 0] * (1 - imt[:1])
 
         with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
             logits = model(imt[:, 0] * 2 + imt[:, 1])
         dist = torch.distributions.categorical.Categorical(logits=logits)
-        result = (1 - masks) * imt[:, 0] + masks * dist.sample()
+        result = (1 - imt[:, 1]) * imt[:, 0] + imt[:, 1] * dist.sample()
         record.append(result)
 
     return torch.cat(record)