From: François Fleuret Date: Sat, 21 Oct 2023 21:07:54 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=0d86d8ca945722438d3c85cd01b3740269ed3546;p=picoclvr.git Update. --- diff --git a/main.py b/main.py index 6a1e8c1..6e87cda 100755 --- a/main.py +++ b/main.py @@ -257,7 +257,7 @@ default_task_args = { "degradation": { "model": "37M", "batch_size": 25, - "nb_train_samples": 100000, + "nb_train_samples": 250000, "nb_test_samples": 10000, }, "mnist": { diff --git a/problems.py b/problems.py index 7aa59be..22b6517 100755 --- a/problems.py +++ b/problems.py @@ -22,37 +22,51 @@ class Problem: nb_correct = ((result == input).long() * ar_mask).sum().item() return nb_total, nb_correct + #################### class ProblemDegradation(Problem): - def __init__(self, nb_state_tokens=7, nb_time_steps=10, value_max=100, hard=False): + def __init__(self, nb_state_tokens=5, nb_time_steps=12, value_max=25, hard=False): + assert value_max // nb_state_tokens >= 2 self.nb_state_tokens = nb_state_tokens self.nb_time_steps = nb_time_steps self.value_max = value_max self.hard = hard - def generate_sequences(self,nb): - - x = (torch.rand(nb,self.nb_state_tokens).sort(dim=-1).indices == 0).long() * self.value_max + def generate_sequences(self, nb): + x = ( + torch.rand(nb, self.nb_state_tokens).sort(dim=-1).indices == 0 + ).long() * self.value_max seq = [x] - for t in range(self.nb_time_steps-1): - v = torch.rand(x.size()) * (x > 0).float() - u = (v.max(dim=-1,keepdim=True).values == v).long() - n = (u*x*torch.rand(x.size())).long().sum(dim=-1,keepdim=True) // 2 - x = x + n * (u.roll(shifts=-1,dims=-1) - 2 * u + u.roll(shifts=1,dims=-1)) + for t in range(self.nb_time_steps - 1): + v = (torch.rand(x.size()).sort(dim=-1).indices + 1) * (x >= 2).long() + u = (v.max(dim=-1, keepdim=True).values == v).long() + n = ( + (u * x) + .minimum(2 + torch.randint(self.value_max // 4 - 2, x.size())) + .sum(dim=-1, keepdim=True) + ) + m = 1 + ((n - 1) * torch.rand(n.size())).long() + x = ( + x + + m * u.roll(shifts=-1, dims=-1) + - n * u + + (n - m) * u.roll(shifts=1, dims=-1) + ) seq.append(x) - if self.hard: seq.reverse() + if self.hard: + seq.reverse() - seq = torch.cat(seq,dim=1) - return seq,seq.new_full(seq.size(), 1, dtype=torch.int64) + seq = torch.cat(seq, dim=1) + return seq, seq.new_full(seq.size(), 1, dtype=torch.int64) def compute_nb_correct(self, input, ar_mask, result): nb_total = result.size(0) nb_correct = 0 - e=result.new_zeros(self.nb_state_tokens) + e = result.new_zeros(self.nb_state_tokens) for seq in result: states = list(seq.split(self.nb_state_tokens)) @@ -60,27 +74,38 @@ class ProblemDegradation(Problem): states.reverse() d = states[0] - j=d.sort(descending=True).indices[0] + j = d.sort(descending=True).indices[0] e.zero_() - e[j]=self.value_max - if (d-e).abs().sum() == 0: + e[j] = self.value_max + if (d - e).abs().sum() == 0: nb_errors = 0 - for k in range(len(states)-1): - d=states[k]-states[k+1] - j=d.sort(descending=True).indices[0] - e.zero_() - e[j]=d[j] - e[(j+1)%e.size(0)]=-d[j]//2 - e[(j-1)%e.size(0)]=-d[j]//2 - if (d-e).abs().sum() > 0: + for k in range(len(states) - 1): + d = states[k + 1] - states[k] + j = d.sort(descending=False).indices[0] + if ( + d[j] == 0 + or d[j] > self.value_max // 4 + or d[(j + 1) % e.size(0)] <= 0 + or d[(j + 1) % e.size(0)] >= -d[j] + ): nb_errors += 1 + else: + e.zero_() + e[j] = d[j] + e[(j + 1) % e.size(0)] = d[(j + 1) % e.size(0)] + e[(j - 1) % e.size(0)] = -d[(j + 1) % e.size(0)] - d[j] + if (d - e).abs().sum() > 0: + nb_errors += 1 if nb_errors == 0: nb_correct += 1 return nb_total, nb_correct def seq2str(self, seq): - return " | ".join( [ " ".join([f"{x:02d}" for x in s ]) for s in seq.split(self.nb_state_tokens) ] ) + return " | ".join( + [" ".join([f"{x:02d}" for x in s]) for s in seq.split(self.nb_state_tokens)] + ) + #################### diff --git a/qmlp.py b/qmlp.py index b58598a..abebfc1 100755 --- a/qmlp.py +++ b/qmlp.py @@ -92,23 +92,37 @@ def generate_sets_and_params( test_input = dequantize(q_test_input, -1, 1) if save_as_examples: - a = 2 * torch.arange(nb_quantization_levels).float() / (nb_quantization_levels - 1) - 1 - xf = torch.cat([a[:,None,None].expand(nb_quantization_levels, nb_quantization_levels,1), - a[None,:,None].expand(nb_quantization_levels, nb_quantization_levels,1)], 2) - xf = xf.reshape(1,-1,2).expand(min(q_train_input.size(0),10),-1,-1) + a = ( + 2 + * torch.arange(nb_quantization_levels).float() + / (nb_quantization_levels - 1) + - 1 + ) + xf = torch.cat( + [ + a[:, None, None].expand( + nb_quantization_levels, nb_quantization_levels, 1 + ), + a[None, :, None].expand( + nb_quantization_levels, nb_quantization_levels, 1 + ), + ], + 2, + ) + xf = xf.reshape(1, -1, 2).expand(min(q_train_input.size(0), 10), -1, -1) print(f"{xf.size()=} {x.size()=}") yf = ( ( - (xf[:, None, :, 0] >= rec_support[:xf.size(0), :, None, 0]).long() - * (xf[:, None, :, 0] <= rec_support[:xf.size(0), :, None, 1]).long() - * (xf[:, None, :, 1] >= rec_support[:xf.size(0), :, None, 2]).long() - * (xf[:, None, :, 1] <= rec_support[:xf.size(0), :, None, 3]).long() + (xf[:, None, :, 0] >= rec_support[: xf.size(0), :, None, 0]).long() + * (xf[:, None, :, 0] <= rec_support[: xf.size(0), :, None, 1]).long() + * (xf[:, None, :, 1] >= rec_support[: xf.size(0), :, None, 2]).long() + * (xf[:, None, :, 1] <= rec_support[: xf.size(0), :, None, 3]).long() ) .max(dim=1) .values ) - full_input, full_targets = xf,yf + full_input, full_targets = xf, yf q_full_input = quantize(full_input, -1, 1) full_input = dequantize(q_full_input, -1, 1) @@ -208,8 +222,12 @@ def generate_sets_and_params( def evaluate_q_params( - q_params, q_set, batch_size=25, device=torch.device("cpu"), nb_mlps_per_batch=1024, - save_as_examples=False, + q_params, + q_set, + batch_size=25, + device=torch.device("cpu"), + nb_mlps_per_batch=1024, + save_as_examples=False, ): errors = [] nb_mlps = q_params.size(0) diff --git a/tasks.py b/tasks.py index 0858282..7a4abbe 100755 --- a/tasks.py +++ b/tasks.py @@ -110,15 +110,20 @@ class SandBox(Task): self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 - # A bit of paranoia never hurts - assert ( - self.nb_codes <= max_nb_codes - and self.train_input.min() >= 0 - and self.test_input.min() >= 0 - and tuple(x.item() for x in self.train_ar_mask.unique()) in { (0,), (1,), (0,1) } - and tuple(x.item() for x in self.test_ar_mask.unique()) in { (0,), (1,), (0,1) } - ) + assert self.nb_codes <= max_nb_codes + assert self.train_input.min() >= 0 + assert self.test_input.min() >= 0 + assert tuple(x.item() for x in self.train_ar_mask.unique()) in { + (0,), + (1,), + (0, 1), + } + assert tuple(x.item() for x in self.test_ar_mask.unique()) in { + (0,), + (1,), + (0, 1), + } def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} @@ -152,16 +157,21 @@ class SandBox(Task): device=self.device, ) + log_ground_truth = ar_mask.min() == 0 + if logger is not None: for sp, st in zip(result[:10], input[:10]): logger( f"test_sequences {n_epoch} prediction {self.problem.seq2str(sp)}" ) - logger( - f" {n_epoch} ground truth {self.problem.seq2str(st)}" - ) + if log_ground_truth: + logger( + f" {n_epoch} ground truth {self.problem.seq2str(st)}" + ) - nb_total, nb_correct = self.problem.compute_nb_correct(input, ar_mask, result) + nb_total, nb_correct = self.problem.compute_nb_correct( + input, ar_mask, result + ) # nb_total = ar_mask.sum().item() # nb_correct = ((result == input).long() * ar_mask).sum().item()