From: François Fleuret Date: Tue, 17 Sep 2024 13:12:40 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=240c59a123f0ed07c3faa0d02589c1348f41b365;p=culture.git Update. --- diff --git a/main.py b/main.py index a353868..51e0fa2 100755 --- a/main.py +++ b/main.py @@ -433,12 +433,12 @@ def predict(model, imt_set, local_device=main_device): ): # some paranoia imt = imt.clone() - imt[:, 0] = imt[:, 0] * (1 - imt[:1]) + imt[:, 0] = imt[:, 0] * (1 - imt[:, 1]) with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): logits = model(imt[:, 0] * 2 + imt[:, 1]) dist = torch.distributions.categorical.Categorical(logits=logits) - result = (1 - masks) * imt[:, 0] + masks * dist.sample() + result = (1 - imt[:, 1]) * imt[:, 0] + imt[:, 1] * dist.sample() record.append(result) return torch.cat(record)