diff --git a/demo/requirements.txt b/demo/requirements.txt index 3ee9ebd..f6f1022 100644 --- a/demo/requirements.txt +++ b/demo/requirements.txt @@ -1,2 +1,2 @@ raidionicsrads@git+https://github.com/dbouget/raidionics_rads_lib -gradio==3.50.2 +gradio==4.29.0 diff --git a/demo/src/css_style.py b/demo/src/css_style.py index f6a7d7a..2573f56 100644 --- a/demo/src/css_style.py +++ b/demo/src/css_style.py @@ -9,8 +9,12 @@ #upload { height: 110px; } +#download { +height: 47px; +width: 150px; +} #run-button { -height: 110px; +height: 47px; width: 150px; } #toggle-button { diff --git a/demo/src/gui.py b/demo/src/gui.py index 74df10d..2686eaf 100644 --- a/demo/src/gui.py +++ b/demo/src/gui.py @@ -59,7 +59,8 @@ def __init__( visible=True, elem_id="model-3d", camera_position=[90, 180, 768], - ).style(height=512) + height=512, + ) def set_class_name(self, value): LOGGER.info(f"Changed task to: {value}") @@ -75,30 +76,44 @@ def upload_file(self, file): def process(self, mesh_file_name): path = mesh_file_name.name + curr = path.split("/")[-1] + self.extension = ".".join(curr.split(".")[1:]) + self.filename = ( + curr.split(".")[0] + "-" + self.class_names[self.class_name] + ) run_model( path, model_path=os.path.join(self.cwd, "resources/models/"), task=self.class_names[self.class_name], name=self.result_names[self.class_name], + output_filename=self.filename + "." + self.extension, ) LOGGER.info("Converting prediction NIfTI to OBJ...") - nifti_to_obj("prediction.nii.gz") + nifti_to_obj(path=self.filename + "." + self.extension) LOGGER.info("Loading CT to numpy...") self.images = load_ct_to_numpy(path) LOGGER.info("Loading prediction volume to numpy..") - self.pred_images = load_pred_volume_to_numpy("./prediction.nii.gz") + self.pred_images = load_pred_volume_to_numpy( + self.filename + "." + self.extension + ) return "./prediction.obj" + def download_prediction(self): + if (not self.filename) or (not self.extension): + LOGGER.error( + "The prediction is not available or ready to download. Wait until the result is available in the 3D viewer." + ) + return self.filename + "." + self.extension + def get_img_pred_pair(self, k): k = int(k) out = gr.AnnotatedImage( self.combine_ct_and_seg(self.images[k], self.pred_images[k]), visible=True, elem_id="model-2d", - ).style( color_map={self.class_name: "#ffae00"}, height=512, width=512, @@ -117,20 +132,18 @@ def run(self): placeholder="\n" * 16, label="Logs", info="Verbose from inference will be displayed below.", - lines=38, - max_lines=38, + lines=36, + max_lines=36, autoscroll=True, elem_id="logs", show_copy_button=True, - scroll_to_output=False, container=True, - line_breaks=True, ) demo.load(read_logs, None, logs, every=1) with gr.Column(): with gr.Row(): - with gr.Column(scale=0.2, min_width=150): + with gr.Column(scale=1, min_width=150): sidebar_state = gr.State(True) btn_toggle_sidebar = gr.Button( @@ -149,7 +162,9 @@ def run(self): btn_clear_logs.click(flush_logs, [], []) file_output = gr.File( - file_count="single", elem_id="upload" + file_count="single", + elem_id="upload", + scale=3, ) file_output.upload( self.upload_file, file_output, file_output @@ -160,7 +175,7 @@ def run(self): label="Task", info="Which structure to segment.", multiselect=False, - size="sm", + scale=1, ) model_selector.input( fn=lambda x: self.set_class_name(x), @@ -168,14 +183,11 @@ def run(self): outputs=None, ) - with gr.Column(scale=0.2, min_width=150): + with gr.Column(scale=1, min_width=150): run_btn = gr.Button( "Run analysis", variant="primary", elem_id="run-button", - ).style( - full_width=False, - size="lg", ) run_btn.click( fn=lambda x: self.process(x), @@ -183,6 +195,18 @@ def run(self): outputs=self.volume_renderer, ) + download_btn = gr.DownloadButton( + "Download prediction", + visible=True, + variant="secondary", + elem_id="download", + ) + download_btn.click( + fn=self.download_prediction, + inputs=None, + outputs=download_btn, + ) + with gr.Row(): gr.Examples( examples=[ @@ -202,17 +226,16 @@ def run(self): ) with gr.Row(): - with gr.Box(): + with gr.Group(): with gr.Column(): # create dummy image to be replaced by loaded images t = gr.AnnotatedImage( - visible=True, elem_id="model-2d" - ).style( + visible=True, + elem_id="model-2d", color_map={self.class_name: "#ffae00"}, - height=512, - width=512, + # height=512, + # width=512, ) - self.slider.input( self.get_img_pred_pair, self.slider, @@ -221,7 +244,7 @@ def run(self): self.slider.render() - with gr.Box(): + with gr.Group(): # gr.Box(): self.volume_renderer.render() # sharing app publicly -> share=True: diff --git a/demo/src/inference.py b/demo/src/inference.py index 5d3774c..34e495d 100644 --- a/demo/src/inference.py +++ b/demo/src/inference.py @@ -11,6 +11,7 @@ def run_model( verbose: str = "info", task: str = "CT_Airways", name: str = "Airways", + output_filename: str = None, ): if verbose == "debug": logging.getLogger().setLevel(logging.DEBUG) @@ -27,6 +28,9 @@ def run_model( if os.path.exists("./result/"): shutil.rmtree("./result/") + if output_filename is None: + raise ValueError("Please, set output_filename.") + patient_directory = "" output_path = "" try: @@ -84,7 +88,7 @@ def run_model( + "-t1gd_annotation-" + name + ".nii.gz", - "./prediction.nii.gz", + output_filename, ) # Clean-up if os.path.exists(patient_directory):