From: François Fleuret Date: Mon, 19 Dec 2022 20:19:29 +0000 (+0100) Subject: Cosmetics. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=eea23df18f107fc65c810261c7775a9393ef7c8e;p=picoclvr.git Cosmetics. --- diff --git a/main.py b/main.py index 6d9f69d..c01cc8f 100755 --- a/main.py +++ b/main.py @@ -20,7 +20,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") ###################################################################### parser = argparse.ArgumentParser( - description="An implementation of GPT with cache to solve a toy geometric reasonning task." + description="An implementation of GPT with cache to solve a toy geometric reasoning task." ) parser.add_argument("--log_filename", type=str, default="train.log") @@ -421,9 +421,7 @@ class TaskPicoCLVR(Task): f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%" ) - img = picoclvr.descr2img( - result_descr, [0], height=self.height, width=self.width - ) + img = picoclvr.descr2img(result_descr, height=self.height, width=self.width) if img.dim() == 5: if img.size(1) == 1: diff --git a/picoclvr.py b/picoclvr.py index cc937af..bd0470f 100755 --- a/picoclvr.py +++ b/picoclvr.py @@ -241,15 +241,9 @@ def generate( # Extracts the image after in descr as a 1x3xHxW tensor -def descr2img(descr, n, height, width): +def descr2img(descr, height, width): - if type(descr) == list: - return torch.cat([descr2img(d, n, height, width) for d in descr], 0) - - if type(n) == list: - return torch.cat([descr2img(descr, k, height, width) for k in n], 0).unsqueeze( - 0 - ) + result = [] def token2color(t): try: @@ -257,15 +251,15 @@ def descr2img(descr, n, height, width): except KeyError: return [128, 128, 128] - d = descr.split("") - d = d[n + 1] if len(d) > n + 1 else "" - d = d.strip().split(" ")[: height * width] - d = d + [""] * (height * width - len(d)) - d = [token2color(t) for t in d] - img = torch.tensor(d).permute(1, 0) - img = img.reshape(1, 3, height, width) + for d in descr: + d = d.split("")[1] + d = d.strip().split(" ")[: height * width] + d = d + [""] * (height * width - len(d)) + d = [token2color(t) for t in d] + img = torch.tensor(d).permute(1, 0).reshape(1, 3, height, width) + result.append(img) - return img + return torch.cat(result, 0) ###################################################################### @@ -353,7 +347,7 @@ if __name__ == "__main__": for d in descr: f.write(f"{d}\n\n") - img = descr2img(descr, n=0, height=12, width=16) + img = descr2img(descr, height=12, width=16) if img.size(0) == 1: img = F.pad(img, (1, 1, 1, 1), value=64)