From: François Fleuret Date: Tue, 25 Jun 2024 16:16:44 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=e2c3b8046c3fddef8aacb74cf5f848d42044897e;p=culture.git Update. --- diff --git a/main.py b/main.py index 05c3557..524715a 100755 --- a/main.py +++ b/main.py @@ -12,7 +12,8 @@ from torch import nn from torch.nn import functional as F import ffutils -import mygpt, quizz_machine +import mygpt +import sky, quizz_machine # world quizzes vs. culture quizzes @@ -210,6 +211,7 @@ assert args.nb_train_samples % args.batch_size == 0 assert args.nb_test_samples % args.batch_size == 0 quizz_machine = quizz_machine.QuizzMachine( + sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, batch_size=args.physical_batch_size, @@ -390,7 +392,7 @@ def create_c_quizzes( quizz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True) quizz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False) - quizz_machine.save_quizzes( + quizz_machine.problem.save_quizzes( new_c_quizzes[:72], args.result_dir, f"culture_c_quiz_{n_epoch:04d}_{model.id:02d}", diff --git a/problem.py b/problem.py new file mode 100755 index 0000000..25ffc49 --- /dev/null +++ b/problem.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + + +class Problem: + def generate_seq(self, nb): + pass + + def save_quizzes(self, input, result_dir, filename_prefix, logger): + pass + + def direction_tokens(self): + pass diff --git a/quizz_machine.py b/quizz_machine.py index 28b94d1..d63855c 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -66,8 +66,6 @@ def masked_inplace_autoregression( ###################################################################### -import sky - class QuizzMachine: def make_ar_mask(self, input): @@ -76,6 +74,7 @@ class QuizzMachine: def __init__( self, + problem, nb_train_samples, nb_test_samples, batch_size, @@ -85,7 +84,7 @@ class QuizzMachine: ): super().__init__() - self.problem = sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2) + self.problem = problem self.batch_size = batch_size self.device = device @@ -267,17 +266,15 @@ class QuizzMachine: ave_seq_logproba = seq_logproba.mean() - logger(f"{ave_seq_logproba=} {min_ave_seq_logproba=}") - if min_ave_seq_logproba is None: break # Oh man that's ugly - if ave_seq_logproba < min_ave_seq_logproba * 1.1: + if ave_seq_logproba < min_ave_seq_logproba: if d_temperature > 0: d_temperature *= -1 / 3 temperature += d_temperature - elif ave_seq_logproba > min_ave_seq_logproba: + elif ave_seq_logproba > min_ave_seq_logproba * 0.99: if d_temperature < 0: d_temperature *= -1 / 3 temperature += d_temperature diff --git a/sky.py b/sky.py index cb25ea0..ec476a6 100755 --- a/sky.py +++ b/sky.py @@ -14,19 +14,10 @@ from torch.nn import functional as F ###################################################################### +import problem -class Problem: - def generate_seq(self, nb_train_samples): - pass - def save_quizzes(self, input, result_dir, filename_prefix, logger): - pass - - def direction_tokens(self): - pass - - -class Sky: +class Sky(problem.Problem): colors = torch.tensor( [ [255, 255, 255],