From: François Fleuret Date: Mon, 28 Nov 2022 23:23:33 +0000 (+0100) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=0404f762debe33b6779d1440bc03e09a1c0ac4c4;p=pytorch.git Update. --- diff --git a/stack.py b/stack.py index 453e9ed..546b17b 100755 --- a/stack.py +++ b/stack.py @@ -5,18 +5,16 @@ # Written by Francois Fleuret -from torch import is_tensor +from torch import is_tensor, Tensor import sys def exception_hook(exc_type, exc_value, tb): -# tb = tb.tb_next + repr_orig=Tensor.__repr__ + Tensor.__repr__=lambda x: f'{x.size()}:{x.dtype}:{x.device}' while tb: - # x=tb.tb_frame #.f_code - # for field in dir(x): - # print(f'@@@ {field} {getattr(x, field)}') print('--------------------------------------------------') filename = tb.tb_frame.f_code.co_filename name = tb.tb_frame.f_code.co_name @@ -24,16 +22,14 @@ def exception_hook(exc_type, exc_value, tb): print(f' File "{filename}", line {line_no}, in {name}') print(open(filename, 'r').readlines()[line_no-1], end='') - local_vars = tb.tb_frame.f_locals - - for n,v in local_vars.items(): - if is_tensor(v): - print(f' {n} -> {tuple(v.size())}:{v.dtype}:{v.device}') - else: + if exc_type is RuntimeError: + for n,v in tb.tb_frame.f_locals.items(): print(f' {n} -> {v}') tb = tb.tb_next + Tensor.__repr__=repr_orig + print(f'{exc_type.__name__}: {exc_value}') sys.excepthook = exception_hook @@ -56,4 +52,3 @@ if __name__ == '__main__': #print(xxx@mmm) blah(mmm,xxx) blah(xxx,mmm) -