2323SNAPSHOT_NAME = "snapshot-700000.pb"
2424TMP_DIR = Path (__file__ ).parent / "dlc-live-tmp"
2525
26- MODELS_FOLDER = TMP_DIR / "test_models"
26+ MODELS_DIR = TMP_DIR / "test_models"
2727TORCH_MODEL = "resnet_50"
2828TORCH_CONFIG = {
29- "checkpoint" : MODELS_FOLDER / f"exported_quadruped_{ TORCH_MODEL } .pt" ,
29+ "checkpoint" : MODELS_DIR / f"exported_quadruped_{ TORCH_MODEL } .pt" ,
3030 "super_animal" : "superanimal_quadruped" ,
3131}
32- TF_MODEL_DIR = TMP_DIR / "DLC_Dog_resnet_50_iteration-0_shuffle-0"
32+ TF_MODEL_DIR = MODELS_DIR / "DLC_Dog_resnet_50_iteration-0_shuffle-0"
3333
3434
3535def run_pytorch_test (video_file : str , display : bool = False ):
3636 from dlclive .modelzoo .pytorch_model_zoo_export import export_modelzoo_model
3737
3838 if Engine .PYTORCH not in get_available_backends ():
39- raise NotImplementedError (
39+ raise ImportError (
4040 "PyTorch backend is not available. Please ensure PyTorch is installed to run the PyTorch test."
4141 )
4242 # Download model from the DeepLabCut Model Zoo
@@ -64,7 +64,7 @@ def run_pytorch_test(video_file: str, display: bool = False):
6464
6565def run_tensorflow_test (video_file : str , display : bool = False ):
6666 if Engine .TENSORFLOW not in get_available_backends ():
67- raise NotImplementedError (
67+ raise ImportError (
6868 "TensorFlow backend is not available. Please ensure TensorFlow is installed to run the TensorFlow test."
6969 )
7070 model_dir = TF_MODEL_DIR
@@ -93,38 +93,39 @@ def run_tensorflow_test(video_file: str, display: bool = False):
9393
9494
9595def main ():
96- tmp_dir = None
97- try :
98- parser = argparse .ArgumentParser (
99- description = "Test DLC-Live installation by downloading and evaluating a demo DLC project!"
100- )
101- parser .add_argument (
102- "--display" ,
103- action = "store_true" ,
104- default = False ,
105- help = "Run the test and display tracking" ,
106- )
107- parser .add_argument (
108- "--nodisplay" ,
109- action = "store_false" ,
110- dest = "display" ,
111- help = argparse .SUPPRESS ,
112- )
96+ backend_results = {}
97+
98+ parser = argparse .ArgumentParser (
99+ description = "Test DLC-Live installation by downloading and evaluating a demo DLC project!"
100+ )
101+ parser .add_argument (
102+ "--display" ,
103+ action = "store_true" ,
104+ default = False ,
105+ help = "Run the test and display tracking" ,
106+ )
107+ parser .add_argument (
108+ "--nodisplay" ,
109+ action = "store_false" ,
110+ dest = "display" ,
111+ help = argparse .SUPPRESS ,
112+ )
113113
114- args = parser .parse_args ()
115- display = args .display
114+ args = parser .parse_args ()
115+ display = args .display
116116
117- if not display :
118- print ("Running without displaying video" )
117+ if not display :
118+ print ("Running without displaying video" )
119119
120- # make temporary directory
121- print ("\n Creating temporary directory...\n " )
122- tmp_dir = TMP_DIR
123- tmp_dir .mkdir (mode = 0o775 , exist_ok = True )
124- MODELS_FOLDER .mkdir (parents = True , exist_ok = True )
120+ # make temporary directory
121+ print ("\n Creating temporary directory...\n " )
122+ tmp_dir = TMP_DIR
123+ tmp_dir .mkdir (mode = 0o775 , exist_ok = True )
124+ MODELS_DIR .mkdir (parents = True , exist_ok = True )
125125
126- video_file = str (tmp_dir / "dog_clip.avi" )
126+ video_file = str (tmp_dir / "dog_clip.avi" )
127127
128+ try :
128129 # download dog test video from github:
129130 # Use raw.githubusercontent.com for direct file access
130131 if not Path (video_file ).exists ():
@@ -148,33 +149,51 @@ def main():
148149 print ("\n Running PyTorch test...\n " )
149150 run_pytorch_test (video_file , display = display )
150151 any_backend_succeeded = True
152+ backend_results ["pytorch" ] = ("SUCCESS" , None )
151153 elif backend == Engine .TENSORFLOW :
152154 print ("\n Running TensorFlow test...\n " )
153155 run_tensorflow_test (video_file , display = display )
154156 any_backend_succeeded = True
157+ backend_results ["tensorflow" ] = ("SUCCESS" , None )
155158 else :
156159 warnings .warn (
157160 f"Unrecognized backend { backend } , skipping..." , UserWarning
158161 )
159162 except Exception as e :
163+ backend_name = (
164+ "pytorch" if backend == Engine .PYTORCH else
165+ "tensorflow" if backend == Engine .TENSORFLOW else
166+ str (backend )
167+ )
168+ backend_results [backend_name ] = ("ERROR" , str (e ))
160169 backend_failures [backend ] = e
161170 warnings .warn (
162171 f"Error while running test for backend { backend } : { e } . "
163172 "Continuing to test other available backends." ,
164173 UserWarning ,
165174 )
166175
167- if not any_backend_succeeded and backend_failures :
168- failure_messages = "; " .join (
169- f"{ b } : { exc } " for b , exc in backend_failures .items ()
170- )
171- raise RuntimeError (f"All backend tests failed. Details: { failure_messages } " )
176+ print ("\n ---\n Backend test summary:" )
177+ for name in ("tensorflow" , "pytorch" ):
178+ status , _ = backend_results .get (name , ("SKIPPED" , None ))
179+ print (f"{ name :<11} [{ status } ]" )
180+ print ("---" )
181+ for name , (status , error ) in backend_results .items ():
182+ if status == "ERROR" :
183+ print (f"{ name .capitalize ()} error:\n { error } \n " )
184+
185+ if not any_backend_succeeded and backend_failures :
186+ failure_messages = "; " .join (
187+ f"{ b } : { exc } " for b , exc in backend_failures .items ()
188+ )
189+ raise RuntimeError (f"All backend tests failed. Details: { failure_messages } " )
190+
172191
173192 finally :
174193 # deleting temporary files
175194 print ("\n Deleting temporary files...\n " )
176195 try :
177- if tmp_dir is not None and tmp_dir .exists ():
196+ if tmp_dir .exists ():
178197 shutil .rmtree (tmp_dir )
179198 except PermissionError :
180199 warnings .warn (
0 commit comments