From: Francois Fleuret Date: Thu, 3 Sep 2020 04:45:19 +0000 (+0200) Subject: Measures for FP16 and FP32 X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=ccc4799245e235f6048681df85c5ef71a32c3c1b;p=pytorch.git Measures for FP16 and FP32 --- diff --git a/speed.py b/speed.py index e4add26..9e845db 100755 --- a/speed.py +++ b/speed.py @@ -12,21 +12,22 @@ else: nb_runs = 10000 d1, d2, d3 = 2048, 2048, 2048 -a, b = torch.rand(d1, d2).to(device), torch.rand(d2, d3).to(device) +for t in [ torch.float32, torch.float16 ]: + a = torch.rand(d1, d2, device = device, dtype = t) + b = torch.rand(d2, d3, device = device, dtype = t) -sync() -start_time = time.perf_counter() -for k in range(nb_runs): - c = a @ b -sync() -duration = time.perf_counter() - start_time + sync() + start_time = time.perf_counter() + for k in range(nb_runs): + c = a @ b + sync() + duration = time.perf_counter() - start_time -nb_flop = float(nb_runs * d1 * d2 * d3 * 2) # 1 multiply-and-add is 2 ops -speed = nb_flop / duration + nb_flop = float(nb_runs * d1 * d2 * d3 * 2) # 1 multiply-and-add is 2 ops + speed = nb_flop / duration -for u in [ '', 'K', 'M', 'G', 'T', 'P' ]: - if speed < 1e3: break - speed /= 1e3 - -print(f'{speed:.02f} {u}flops on {device}') + for u in [ '', 'K', 'M', 'G', 'T', 'P' ]: + if speed < 1e3: break + speed /= 1e3 + print(f'{speed:.02f} {u}flops with {t} on {device}')