From 9e7f263df9b4e75126ab1e39dfb29757061f5583 Mon Sep 17 00:00:00 2001 From: Tyron Jung Date: Tue, 19 Nov 2024 10:01:15 -0800 Subject: [PATCH] Fix name/display_name discrepancy --- truss-chains/tests/test_utils.py | 2 ++ truss-chains/truss_chains/code_gen.py | 1 + truss-chains/truss_chains/definitions.py | 1 + truss-chains/truss_chains/public_api.py | 1 + truss-chains/truss_chains/remote.py | 16 +++++++++++++--- truss-chains/truss_chains/stub.py | 5 ++++- truss-chains/truss_chains/utils.py | 9 ++++++--- 7 files changed, 28 insertions(+), 7 deletions(-) diff --git a/truss-chains/tests/test_utils.py b/truss-chains/tests/test_utils.py index 03fde8f60..cf1e933b5 100644 --- a/truss-chains/tests/test_utils.py +++ b/truss-chains/tests/test_utils.py @@ -28,6 +28,7 @@ def test_populate_chainlet_service_predict_urls(tmp_path, dynamic_config_mount_d chainlet_to_service = { "HelloWorld": definitions.ServiceDescriptor( name="HelloWorld", + display_name="HelloWorld", options=definitions.RPCOptions(), ) } @@ -54,6 +55,7 @@ def test_no_populate_chainlet_service_predict_urls( chainlet_to_service = { "RandInt": definitions.ServiceDescriptor( name="RandInt", + display_name="RandInt", options=definitions.RPCOptions(), ) } diff --git a/truss-chains/truss_chains/code_gen.py b/truss-chains/truss_chains/code_gen.py index 62f820425..832e7c524 100644 --- a/truss-chains/truss_chains/code_gen.py +++ b/truss-chains/truss_chains/code_gen.py @@ -632,6 +632,7 @@ def gen_truss_chainlet( for dep in chainlet_descriptor.dependencies.values(): dep_services[dep.name] = definitions.ServiceDescriptor( name=dep.name, + display_name=dep.display_name, options=dep.options, ) chainlet_dir = _make_chainlet_dir(chain_name, chainlet_descriptor, gen_root) diff --git a/truss-chains/truss_chains/definitions.py b/truss-chains/truss_chains/definitions.py index 239147057..36a36b530 100644 --- a/truss-chains/truss_chains/definitions.py +++ b/truss-chains/truss_chains/definitions.py @@ -396,6 +396,7 @@ class ServiceDescriptor(SafeModel): specifically with ``StubBase``.""" name: str + display_name: str options: RPCOptions diff --git a/truss-chains/truss_chains/public_api.py b/truss-chains/truss_chains/public_api.py index 210516bf3..ec95df886 100644 --- a/truss-chains/truss_chains/public_api.py +++ b/truss-chains/truss_chains/public_api.py @@ -197,6 +197,7 @@ class HelloWorld(chains.ChainletBase): chainlet_to_service={ "SomeChainlet": chains.DeployedServiceDescriptor( name="SomeChainlet", + display_name="SomeChainlet", predict_url="https://...", options=chains.RPCOptions(), ) diff --git a/truss-chains/truss_chains/remote.py b/truss-chains/truss_chains/remote.py index 5a99bb6b9..b9ad90cfc 100644 --- a/truss-chains/truss_chains/remote.py +++ b/truss-chains/truss_chains/remote.py @@ -331,13 +331,23 @@ def generate_chainlet_artifacts( chain_root = _get_chain_root(entrypoint, non_entrypoint_root_dir) entrypoint_artifact: Optional[b10_types.ChainletArtifact] = None dependency_artifacts: list[b10_types.ChainletArtifact] = [] + chainlet_display_names: set[str] = set() for chainlet_descriptor in _get_ordered_dependencies([entrypoint]): - model_base_name = chainlet_descriptor.display_name + chainlet_display_name = chainlet_descriptor.display_name + + if chainlet_display_name in chainlet_display_names: + raise definitions.ChainsUsageError( + f"Chainlet names must be unique. Found multiple Chainlets with the name: '{chainlet_display_name}'." + ) + + chainlet_display_names.add(chainlet_display_name) + # Since we are creating a distinct model for each deployment of the chain, # we add a random suffix. model_suffix = str(uuid.uuid4()).split("-")[0] - model_name = f"{model_base_name}-{model_suffix}" + model_name = f"{chainlet_display_name}-{model_suffix}" + chainlet_dir = code_gen.gen_truss_chainlet( chain_root, self._gen_root, @@ -349,7 +359,7 @@ def generate_chainlet_artifacts( artifact = b10_types.ChainletArtifact( truss_dir=chainlet_dir, name=chainlet_descriptor.name, - display_name=chainlet_descriptor.display_name, + display_name=chainlet_display_name, ) is_entrypoint = chainlet_descriptor.chainlet_cls == entrypoint diff --git a/truss-chains/truss_chains/stub.py b/truss-chains/truss_chains/stub.py index 3d7679fd9..6e0927a30 100644 --- a/truss-chains/truss_chains/stub.py +++ b/truss-chains/truss_chains/stub.py @@ -248,7 +248,10 @@ def from_url( options = options or definitions.RPCOptions() return cls( service_descriptor=definitions.DeployedServiceDescriptor( - name=cls.__name__, predict_url=predict_url, options=options + name=cls.__name__, + display_name=cls.__name__, + predict_url=predict_url, + options=options, ), api_key=context.get_baseten_api_key(), ) diff --git a/truss-chains/truss_chains/utils.py b/truss-chains/truss_chains/utils.py index 773a0be76..d78b64586 100644 --- a/truss-chains/truss_chains/utils.py +++ b/truss-chains/truss_chains/utils.py @@ -161,16 +161,19 @@ def populate_chainlet_service_predict_urls( chainlet_name, service_descriptor, ) in chainlet_to_service.items(): - if chainlet_name not in dynamic_chainlet_config: + display_name = service_descriptor.display_name + + if display_name not in dynamic_chainlet_config: raise definitions.MissingDependencyError( - f"Chainlet '{chainlet_name}' not found in '{definitions.DYNAMIC_CHAINLET_CONFIG_KEY}'." + f"Chainlet '{display_name}' not found in '{definitions.DYNAMIC_CHAINLET_CONFIG_KEY}'. Dynamic Chainlet config keys: {list(dynamic_chainlet_config)}." ) chainlet_to_deployed_service[chainlet_name] = ( definitions.DeployedServiceDescriptor( + display_name=display_name, name=service_descriptor.name, options=service_descriptor.options, - predict_url=dynamic_chainlet_config[chainlet_name]["predict_url"], + predict_url=dynamic_chainlet_config[display_name]["predict_url"], ) )