Skip to content

Commit

Permalink
Fix preset table doc links (#1995)
Browse files Browse the repository at this point in the history
We had relative links that broke when we changed the site structure,
switch to absolute links.

Also:
- Better sort order for preset table.
- Use less metadata in the table.
  • Loading branch information
mattdangerw authored Nov 26, 2024
1 parent 05faae3 commit ede75bc
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 45 deletions.
2 changes: 1 addition & 1 deletion scripts/autogen.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
GUIDES_GH_LOCATION = Path("keras-team") / "keras-io" / "blob" / "master" / "guides"
KERAS_TEAM_GH = "https://github.com/keras-team"
PROJECT_URL = {
"keras": f"{KERAS_TEAM_GH}/keras/tree/v3.6.0/",
"keras": f"{KERAS_TEAM_GH}/keras/tree/v3.7.0/",
"keras_tuner": f"{KERAS_TEAM_GH}/keras-tuner/tree/v1.4.7/",
"keras_hub": f"{KERAS_TEAM_GH}/keras-hub/tree/v0.17.0/",
"tf_keras": f"{KERAS_TEAM_GH}/tf-keras/tree/v2.18.0/",
Expand Down
92 changes: 48 additions & 44 deletions scripts/render_presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
}
"""

from hub_master import MODELS_MASTER

try:
import keras_hub
except Exception as e:
Expand Down Expand Up @@ -46,10 +48,18 @@ def format_param_count(metadata):

def format_path(metadata):
"""Returns Path for the given preset"""
try:
return f"[{metadata['official_name']}]({metadata['path']})"
except KeyError:
return "Unknown"
for child in MODELS_MASTER["children"]:
path = child["path"].strip("/")
if metadata["path"] == path:
text = child["title"]
link = f"/keras_hub/api/models/{path}"
return f"[{text}]({link})"
return "-"


def format_preset_link(preset, handle):
url = handle.replace("kaggle://", "https://www.kaggle.com/models/")
return f"[{preset}]({url})"


def is_base_class(symbol):
Expand All @@ -61,35 +71,38 @@ def is_base_class(symbol):
)


def render_all_presets(symbols):
"""Renders the markdown table for backbone presets as a string."""
def sort_presets(presets):
# Sort by path and then by parameter count.
return sorted(
presets.keys(),
key=lambda x: (
presets[x]["metadata"]["path"],
presets[x]["metadata"]["params"],
)
)


def render_row(preset, data, add_doc_link=False):
"""Renders a row for a preset in a markdown table."""
metadata = data["metadata"]
url = data["kaggle_handle"]
url = url.replace("kaggle://", "https://www.kaggle.com/models/")
cols = []
cols.append(format_preset_link(preset, data["kaggle_handle"]))
if add_doc_link:
cols.append(format_path(metadata))
cols.append(format_param_count(metadata))
cols.append(metadata["description"])
return " | ".join(cols) + "\n"

table = TABLE_HEADER

# Backbones has alias, which duplicates some presets.
# Use a set to keep them unique.
added_presets = set()
# Bakcbone presets
for name, symbol in symbols:
if is_base_class(symbol) or "Backbone" not in name:
continue
presets = symbol.presets
# Only keep the ones with pretrained weights for KerasCV Backbones.
for preset in presets:
if preset in added_presets:
continue
else:
added_presets.add(preset)
metadata = presets[preset]["metadata"]
url = presets[preset]["kaggle_handle"]
url = url.replace("kaggle://", "https://www.kaggle.com/models/")
table += (
f"[{preset}]({url}) | "
f"{format_path(metadata)} | "
f"{format_param_count(metadata)} | "
f"{metadata['description']}"
)
table += "\n"
def render_all_presets():
"""Renders the markdown table for backbone presets as a string."""
table = TABLE_HEADER
symbol = keras_hub.models.Backbone
for preset in sort_presets(symbol.presets):
data = symbol.presets[preset]
table += render_row(preset, data, add_doc_link=True)
return table


Expand All @@ -100,15 +113,9 @@ def render_table(symbol):
table = TABLE_HEADER_PER_MODEL
if is_base_class(symbol) or len(symbol.presets) == 0:
return None
for preset in symbol.presets:
metadata = symbol.presets[preset]["metadata"]
url = symbol.presets[preset]["kaggle_handle"]
url = url.replace("kaggle://", "https://www.kaggle.com/models/")
table += (
f"[{preset}]({url}) | "
f"{format_param_count(metadata)} | "
f"{metadata['description']} \n"
)
for preset in sort_presets(symbol.presets):
data = symbol.presets[preset]
table += render_row(preset, data)
return table


Expand All @@ -117,9 +124,6 @@ def render_tags(template):
if keras_hub is None:
return template

symbols = keras_hub.models.__dict__.items()
if "{{presets_table}}" in template:
template = template.replace(
"{{presets_table}}", render_all_presets(symbols)
)
template = template.replace("{{presets_table}}", render_all_presets())
return template

0 comments on commit ede75bc

Please sign in to comment.