From: Francois Fleuret Date: Thu, 3 Sep 2020 06:18:03 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=20285925e51c7adc6e7bb64bb9d2a5cab92c6aac;p=pytorch.git Update. --- diff --git a/speed.py b/speed.py index e03b3b7..e5b0e3a 100755 --- a/speed.py +++ b/speed.py @@ -9,30 +9,37 @@ import time, torch if torch.cuda.is_available(): device = torch.device('cuda') - sync = lambda: torch.cuda.synchronize() + sync = torch.cuda.synchronize else: device = torch.device('cpu') sync = lambda: None -nb_runs = 10000 +max_duration = 30 d1, d2, d3 = 2048, 2048, 2048 for t in [ torch.float32, torch.float16 ]: - a = torch.rand(d1, d2, device = device, dtype = t) - b = torch.rand(d2, d3, device = device, dtype = t) + try: + a = torch.rand(d1, d2, device = device, dtype = t) + b = torch.rand(d2, d3, device = device, dtype = t) + nb_runs = 0 - sync() - start_time = time.perf_counter() - for k in range(nb_runs): - c = a @ b - sync() - duration = time.perf_counter() - start_time + sync() + start_time = time.perf_counter() + while time.perf_counter() - start_time < max_duration: + c = a @ b + nb_runs += 1 + sync() + duration = time.perf_counter() - start_time - nb_flop = float(nb_runs * d1 * d2 * d3 * 2) # 1 multiply-and-add is 2 ops - speed = nb_flop / duration + nb_flop = float(nb_runs * d1 * d2 * d3 * 2) # 1 multiply-and-add is 2 ops + speed = nb_flop / duration - for u in [ '', 'K', 'M', 'G', 'T', 'P' ]: - if speed < 1e3: break - speed /= 1e3 + for u in [ '', 'K', 'M', 'G', 'T', 'P' ]: + if speed < 1e3: break + speed /= 1e3 - print(f'{speed:.02f} {u}flops with {t} on {device}') + print(f'{speed:.02f} {u}flops with {t} on {device}') + + except: + + print(f'Cannot try with {t}')