Skip to content

Commit

Permalink
Merge pull request #59 from andreped/download-button
Browse files Browse the repository at this point in the history
Download button in demo app; upgraded gradio to latest
  • Loading branch information
andreped authored Jun 21, 2024
2 parents f6ffe33 + 8291c90 commit 35de495
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 25 deletions.
2 changes: 1 addition & 1 deletion demo/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
raidionicsrads@git+https://github.com/dbouget/raidionics_rads_lib
gradio==3.50.2
gradio==4.29.0
6 changes: 5 additions & 1 deletion demo/src/css_style.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@
#upload {
height: 110px;
}
#download {
height: 47px;
width: 150px;
}
#run-button {
height: 110px;
height: 47px;
width: 150px;
}
#toggle-button {
Expand Down
67 changes: 45 additions & 22 deletions demo/src/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ def __init__(
visible=True,
elem_id="model-3d",
camera_position=[90, 180, 768],
).style(height=512)
height=512,
)

def set_class_name(self, value):
LOGGER.info(f"Changed task to: {value}")
Expand All @@ -75,30 +76,44 @@ def upload_file(self, file):

def process(self, mesh_file_name):
path = mesh_file_name.name
curr = path.split("/")[-1]
self.extension = ".".join(curr.split(".")[1:])
self.filename = (
curr.split(".")[0] + "-" + self.class_names[self.class_name]
)
run_model(
path,
model_path=os.path.join(self.cwd, "resources/models/"),
task=self.class_names[self.class_name],
name=self.result_names[self.class_name],
output_filename=self.filename + "." + self.extension,
)
LOGGER.info("Converting prediction NIfTI to OBJ...")
nifti_to_obj("prediction.nii.gz")
nifti_to_obj(path=self.filename + "." + self.extension)

LOGGER.info("Loading CT to numpy...")
self.images = load_ct_to_numpy(path)

LOGGER.info("Loading prediction volume to numpy..")
self.pred_images = load_pred_volume_to_numpy("./prediction.nii.gz")
self.pred_images = load_pred_volume_to_numpy(
self.filename + "." + self.extension
)

return "./prediction.obj"

def download_prediction(self):
if (not self.filename) or (not self.extension):
LOGGER.error(
"The prediction is not available or ready to download. Wait until the result is available in the 3D viewer."
)
return self.filename + "." + self.extension

def get_img_pred_pair(self, k):
k = int(k)
out = gr.AnnotatedImage(
self.combine_ct_and_seg(self.images[k], self.pred_images[k]),
visible=True,
elem_id="model-2d",
).style(
color_map={self.class_name: "#ffae00"},
height=512,
width=512,
Expand All @@ -117,20 +132,18 @@ def run(self):
placeholder="\n" * 16,
label="Logs",
info="Verbose from inference will be displayed below.",
lines=38,
max_lines=38,
lines=36,
max_lines=36,
autoscroll=True,
elem_id="logs",
show_copy_button=True,
scroll_to_output=False,
container=True,
line_breaks=True,
)
demo.load(read_logs, None, logs, every=1)

with gr.Column():
with gr.Row():
with gr.Column(scale=0.2, min_width=150):
with gr.Column(scale=1, min_width=150):
sidebar_state = gr.State(True)

btn_toggle_sidebar = gr.Button(
Expand All @@ -149,7 +162,9 @@ def run(self):
btn_clear_logs.click(flush_logs, [], [])

file_output = gr.File(
file_count="single", elem_id="upload"
file_count="single",
elem_id="upload",
scale=3,
)
file_output.upload(
self.upload_file, file_output, file_output
Expand All @@ -160,29 +175,38 @@ def run(self):
label="Task",
info="Which structure to segment.",
multiselect=False,
size="sm",
scale=1,
)
model_selector.input(
fn=lambda x: self.set_class_name(x),
inputs=model_selector,
outputs=None,
)

with gr.Column(scale=0.2, min_width=150):
with gr.Column(scale=1, min_width=150):
run_btn = gr.Button(
"Run analysis",
variant="primary",
elem_id="run-button",
).style(
full_width=False,
size="lg",
)
run_btn.click(
fn=lambda x: self.process(x),
inputs=file_output,
outputs=self.volume_renderer,
)

download_btn = gr.DownloadButton(
"Download prediction",
visible=True,
variant="secondary",
elem_id="download",
)
download_btn.click(
fn=self.download_prediction,
inputs=None,
outputs=download_btn,
)

with gr.Row():
gr.Examples(
examples=[
Expand All @@ -202,17 +226,16 @@ def run(self):
)

with gr.Row():
with gr.Box():
with gr.Group():
with gr.Column():
# create dummy image to be replaced by loaded images
t = gr.AnnotatedImage(
visible=True, elem_id="model-2d"
).style(
visible=True,
elem_id="model-2d",
color_map={self.class_name: "#ffae00"},
height=512,
width=512,
# height=512,
# width=512,
)

self.slider.input(
self.get_img_pred_pair,
self.slider,
Expand All @@ -221,7 +244,7 @@ def run(self):

self.slider.render()

with gr.Box():
with gr.Group(): # gr.Box():
self.volume_renderer.render()

# sharing app publicly -> share=True:
Expand Down
6 changes: 5 additions & 1 deletion demo/src/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def run_model(
verbose: str = "info",
task: str = "CT_Airways",
name: str = "Airways",
output_filename: str = None,
):
if verbose == "debug":
logging.getLogger().setLevel(logging.DEBUG)
Expand All @@ -27,6 +28,9 @@ def run_model(
if os.path.exists("./result/"):
shutil.rmtree("./result/")

if output_filename is None:
raise ValueError("Please, set output_filename.")

patient_directory = ""
output_path = ""
try:
Expand Down Expand Up @@ -84,7 +88,7 @@ def run_model(
+ "-t1gd_annotation-"
+ name
+ ".nii.gz",
"./prediction.nii.gz",
output_filename,
)
# Clean-up
if os.path.exists(patient_directory):
Expand Down

0 comments on commit 35de495

Please sign in to comment.