Skip to content

Commit

Permalink
Fix name/display_name discrepancy
Browse files Browse the repository at this point in the history
  • Loading branch information
tyranitar committed Nov 19, 2024
1 parent 9e8c458 commit 9e7f263
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 7 deletions.
2 changes: 2 additions & 0 deletions truss-chains/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def test_populate_chainlet_service_predict_urls(tmp_path, dynamic_config_mount_d
chainlet_to_service = {
"HelloWorld": definitions.ServiceDescriptor(
name="HelloWorld",
display_name="HelloWorld",
options=definitions.RPCOptions(),
)
}
Expand All @@ -54,6 +55,7 @@ def test_no_populate_chainlet_service_predict_urls(
chainlet_to_service = {
"RandInt": definitions.ServiceDescriptor(
name="RandInt",
display_name="RandInt",
options=definitions.RPCOptions(),
)
}
Expand Down
1 change: 1 addition & 0 deletions truss-chains/truss_chains/code_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,7 @@ def gen_truss_chainlet(
for dep in chainlet_descriptor.dependencies.values():
dep_services[dep.name] = definitions.ServiceDescriptor(
name=dep.name,
display_name=dep.display_name,
options=dep.options,
)
chainlet_dir = _make_chainlet_dir(chain_name, chainlet_descriptor, gen_root)
Expand Down
1 change: 1 addition & 0 deletions truss-chains/truss_chains/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ class ServiceDescriptor(SafeModel):
specifically with ``StubBase``."""

name: str
display_name: str
options: RPCOptions


Expand Down
1 change: 1 addition & 0 deletions truss-chains/truss_chains/public_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ class HelloWorld(chains.ChainletBase):
chainlet_to_service={
"SomeChainlet": chains.DeployedServiceDescriptor(
name="SomeChainlet",
display_name="SomeChainlet",
predict_url="https://...",
options=chains.RPCOptions(),
)
Expand Down
16 changes: 13 additions & 3 deletions truss-chains/truss_chains/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,13 +331,23 @@ def generate_chainlet_artifacts(
chain_root = _get_chain_root(entrypoint, non_entrypoint_root_dir)
entrypoint_artifact: Optional[b10_types.ChainletArtifact] = None
dependency_artifacts: list[b10_types.ChainletArtifact] = []
chainlet_display_names: set[str] = set()

for chainlet_descriptor in _get_ordered_dependencies([entrypoint]):
model_base_name = chainlet_descriptor.display_name
chainlet_display_name = chainlet_descriptor.display_name

if chainlet_display_name in chainlet_display_names:
raise definitions.ChainsUsageError(
f"Chainlet names must be unique. Found multiple Chainlets with the name: '{chainlet_display_name}'."
)

chainlet_display_names.add(chainlet_display_name)

# Since we are creating a distinct model for each deployment of the chain,
# we add a random suffix.
model_suffix = str(uuid.uuid4()).split("-")[0]
model_name = f"{model_base_name}-{model_suffix}"
model_name = f"{chainlet_display_name}-{model_suffix}"

chainlet_dir = code_gen.gen_truss_chainlet(
chain_root,
self._gen_root,
Expand All @@ -349,7 +359,7 @@ def generate_chainlet_artifacts(
artifact = b10_types.ChainletArtifact(
truss_dir=chainlet_dir,
name=chainlet_descriptor.name,
display_name=chainlet_descriptor.display_name,
display_name=chainlet_display_name,
)

is_entrypoint = chainlet_descriptor.chainlet_cls == entrypoint
Expand Down
5 changes: 4 additions & 1 deletion truss-chains/truss_chains/stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,10 @@ def from_url(
options = options or definitions.RPCOptions()
return cls(
service_descriptor=definitions.DeployedServiceDescriptor(
name=cls.__name__, predict_url=predict_url, options=options
name=cls.__name__,
display_name=cls.__name__,
predict_url=predict_url,
options=options,
),
api_key=context.get_baseten_api_key(),
)
Expand Down
9 changes: 6 additions & 3 deletions truss-chains/truss_chains/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,16 +161,19 @@ def populate_chainlet_service_predict_urls(
chainlet_name,
service_descriptor,
) in chainlet_to_service.items():
if chainlet_name not in dynamic_chainlet_config:
display_name = service_descriptor.display_name

if display_name not in dynamic_chainlet_config:
raise definitions.MissingDependencyError(
f"Chainlet '{chainlet_name}' not found in '{definitions.DYNAMIC_CHAINLET_CONFIG_KEY}'."
f"Chainlet '{display_name}' not found in '{definitions.DYNAMIC_CHAINLET_CONFIG_KEY}'. Dynamic Chainlet config keys: {list(dynamic_chainlet_config)}."
)

chainlet_to_deployed_service[chainlet_name] = (
definitions.DeployedServiceDescriptor(
display_name=display_name,
name=service_descriptor.name,
options=service_descriptor.options,
predict_url=dynamic_chainlet_config[chainlet_name]["predict_url"],
predict_url=dynamic_chainlet_config[display_name]["predict_url"],
)
)

Expand Down

0 comments on commit 9e7f263

Please sign in to comment.