@@ -85,8 +85,19 @@ def run_tensorflow_test(video_file: str, display: bool = False):
8585 )
8686
8787
88+ BACKEND_TESTS = {
89+ Engine .PYTORCH : run_pytorch_test ,
90+ Engine .TENSORFLOW : run_tensorflow_test ,
91+ }
92+ BACKEND_DISPLAY_NAMES = {
93+ Engine .PYTORCH : "PyTorch" ,
94+ Engine .TENSORFLOW : "TensorFlow" ,
95+ }
96+
97+
8898def main ():
89- backend_results = {}
99+ backend_results : dict [Engine , tuple [str , str | None ]] = {}
100+ backend_failures : dict [Engine , Exception ] = {}
90101
91102 parser = argparse .ArgumentParser (
92103 description = "Test DLC-Live installation by downloading and evaluating a demo DLC project!"
@@ -111,7 +122,6 @@ def main():
111122 if not display :
112123 print ("Running without displaying video" )
113124
114- # make temporary directory
115125 print ("\n Creating temporary directory...\n " )
116126 tmp_dir = TMP_DIR
117127 tmp_dir .mkdir (mode = 0o775 , exist_ok = True )
@@ -132,10 +142,9 @@ def main():
132142 else :
133143 print (f"Video file already exists at { video_file } , skipping download." )
134144
135- # assert these things exist so we can give informative error messages
136145 if not Path (video_file ).exists ():
137146 raise FileNotFoundError (f"Missing video file { video_file } " )
138- backend_failures = {}
147+
139148 any_backend_succeeded = False
140149
141150 available_backends = get_available_backends ()
@@ -147,28 +156,23 @@ def main():
147156 )
148157
149158 for backend in available_backends :
159+ test_func = BACKEND_TESTS .get (backend )
160+ if test_func is None :
161+ warnings .warn (
162+ f"No test function defined for backend { backend } , skipping..." ,
163+ UserWarning ,
164+ stacklevel = 2 ,
165+ )
166+ continue
167+
150168 try :
151- if backend == Engine .PYTORCH :
152- print ("\n Running PyTorch test...\n " )
153- run_pytorch_test (video_file , display = display )
154- any_backend_succeeded = True
155- backend_results ["pytorch" ] = ("SUCCESS" , None )
156- elif backend == Engine .TENSORFLOW :
157- print ("\n Running TensorFlow test...\n " )
158- run_tensorflow_test (video_file , display = display )
159- any_backend_succeeded = True
160- backend_results ["tensorflow" ] = ("SUCCESS" , None )
161- else :
162- warnings .warn (f"Unrecognized backend { backend } , skipping..." , UserWarning , stacklevel = 2 )
169+ print (f"\n Running { BACKEND_DISPLAY_NAMES .get (backend , backend .value )} test...\n " )
170+ test_func (video_file , display = display )
171+ any_backend_succeeded = True
172+ backend_results [backend ] = ("SUCCESS" , None )
173+
163174 except Exception as e :
164- backend_name = (
165- "pytorch"
166- if backend == Engine .PYTORCH
167- else "tensorflow"
168- if backend == Engine .TENSORFLOW
169- else str (backend )
170- )
171- backend_results [backend_name ] = ("ERROR" , str (e ))
175+ backend_results [backend ] = ("ERROR" , str (e ))
172176 backend_failures [backend ] = e
173177 warnings .warn (
174178 f"Error while running test for backend { backend } : { e } . "
@@ -178,16 +182,18 @@ def main():
178182 )
179183
180184 print ("\n ---\n Backend test summary:" )
181- for name in ( "tensorflow" , "pytorch" ):
182- status , _ = backend_results .get (name , ("SKIPPED" , None ))
183- print (f"{ name :<11} [{ status } ]" )
185+ for backend in BACKEND_TESTS . keys ( ):
186+ status , _ = backend_results .get (backend , ("SKIPPED" , None ))
187+ print (f"{ backend . value :<11} [{ status } ]" )
184188 print ("---" )
185- for name , (status , error ) in backend_results .items ():
189+
190+ for backend , (status , error ) in backend_results .items ():
186191 if status == "ERROR" :
187- print (f"{ name .capitalize ()} error:\n { error } \n " )
192+ backend_name = BACKEND_DISPLAY_NAMES .get (backend , backend .value )
193+ print (f"{ backend_name } error:\n { error } \n " )
188194
189195 if not any_backend_succeeded and backend_failures :
190- failure_messages = "; " .join (f"{ b } : { exc } " for b , exc in backend_failures .items ())
196+ failure_messages = "; " .join (f"{ b } : { e } " for b , e in backend_failures .items ())
191197 raise RuntimeError (f"All backend tests failed. Details: { failure_messages } " )
192198
193199 finally :
@@ -198,7 +204,8 @@ def main():
198204 shutil .rmtree (tmp_dir )
199205 except PermissionError :
200206 warnings .warn (
201- f"Could not delete temporary directory { str (tmp_dir )} due to a permissions error." , stacklevel = 2
207+ f"Could not delete temporary directory { str (tmp_dir )} due to a permissions error." ,
208+ stacklevel = 2 ,
202209 )
203210
204211
0 commit comments