parser.add_argument("--sandbox_levels_nb_items", type=int, default=25)
-parser.add_argument("--sandbox_levels_len_source", type=int, default=5)
+parser.add_argument("--sandbox_levels_len_source", type=int, default=6)
parser.add_argument("--sandbox_levels_len_result", type=int, default=8)
default_args = {
"sandbox": {
- "nb_epochs": 10,
+ "nb_epochs": 50,
"batch_size": 25,
- "nb_train_samples": 25000,
+ "nb_train_samples": 100000,
"nb_test_samples": 10000,
},
"picoclvr": {
// 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
) % 10
marker1 = torch.full((nb, 1), 10)
- source = torch.randint(10, (nb, self.len_source))
+ # source = torch.randint(10, (nb, self.len_source))
+ source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
marker2 = torch.full((nb, 1), 11)
result = operators.bmm(source[:, :, None]).squeeze(-1)
print(f"{nb_operators.dtype=} {marker1.dtype=}")
torch.rand(nb, self.len_result, self.len_source).argmax(-1),
num_classes=self.len_source,
)
- source1 = torch.randint(10, (nb, self.len_source))
+ source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
+ # source1 = torch.randint(10, (nb, self.len_source))
marker1 = torch.full((nb, 1), 10)
result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
marker2 = torch.full((nb, 1), 11)