From: Francois Fleuret Date: Sun, 25 Jun 2017 07:54:19 +0000 (+0200) Subject: Making an even deeper model. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=b3c335857859d457575128690e4aa77f52d17e5c;p=pysvrt.git Making an even deeper model. --- diff --git a/cnn-svrt.py b/cnn-svrt.py index 4481049..cb94184 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -264,6 +264,61 @@ class DeepNet2(nn.Module): ###################################################################### +class DeepNet3(nn.Module): + name = 'deepnet3' + + def __init__(self): + super(DeepNet2, self).__init__() + self.conv1 = nn.Conv2d( 1, 32, kernel_size=7, stride=4, padding=3) + self.conv2 = nn.Conv2d( 32, 256, kernel_size=5, padding=2) + self.conv3 = nn.Conv2d(256, 256, kernel_size=3, padding=1) + self.conv4 = nn.Conv2d(256, 256, kernel_size=3, padding=1) + self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1) + self.conv6 = nn.Conv2d(256, 256, kernel_size=3, padding=1) + self.conv7 = nn.Conv2d(256, 256, kernel_size=3, padding=1) + self.fc1 = nn.Linear(4096, 512) + self.fc2 = nn.Linear(512, 512) + self.fc3 = nn.Linear(512, 2) + + def forward(self, x): + x = self.conv1(x) + x = fn.max_pool2d(x, kernel_size=2) + x = fn.relu(x) + + x = self.conv2(x) + x = fn.max_pool2d(x, kernel_size=2) + x = fn.relu(x) + + x = self.conv3(x) + x = fn.relu(x) + + x = self.conv4(x) + x = fn.relu(x) + + x = self.conv5(x) + x = fn.max_pool2d(x, kernel_size=2) + x = fn.relu(x) + + x = self.conv6(x) + x = fn.relu(x) + + x = self.conv7(x) + x = fn.relu(x) + + x = x.view(-1, 4096) + + x = self.fc1(x) + x = fn.relu(x) + + x = self.fc2(x) + x = fn.relu(x) + + x = self.fc3(x) + + return x + +###################################################################### + def nb_errors(model, data_set): ne = 0 for b in range(0, data_set.nb_batches): @@ -385,7 +440,7 @@ else: ######################################## model_class = None -for m in [ AfrozeShallowNet, AfrozeDeepNet, DeepNet2 ]: +for m in [ AfrozeShallowNet, AfrozeDeepNet, DeepNet2, DeepNet3 ]: if args.model == m.name: model_class = m break @@ -415,7 +470,6 @@ for problem_number in map(int, args.problems.split(',')): ################################################## # Tries to load the model - need_to_train = False try: model_state_dict, nb_epochs_done = torch.load(model_filename) model.load_state_dict(model_state_dict) @@ -469,7 +523,7 @@ for problem_number in map(int, args.problems.split(',')): ################################################## # Test if necessary - if need_to_train or args.test_loaded_models: + if nb_epochs_done < args.nb_epochs or args.test_loaded_models: t = time.time()