diff --git a/frontend/lib/resolution/resolution-queries.cpp b/frontend/lib/resolution/resolution-queries.cpp index 158828a5ab00..0970e2371d0a 100644 --- a/frontend/lib/resolution/resolution-queries.cpp +++ b/frontend/lib/resolution/resolution-queries.cpp @@ -2150,8 +2150,8 @@ ApplicabilityResult instantiateSignature(ResolutionContext* rc, const TypedFnSignature* parentSignature = sig->parentFn(); if (parentSignature) { for (auto up = parentSignature; up; up = up->parentFn()) { - CHPL_ASSERT(!up->needsInstantiation()); if (up->needsInstantiation()) { + CHPL_UNIMPL("parent function needs instantiation"); return ApplicabilityResult::failure(sig->id(), FAIL_CANDIDATE_OTHER); } } diff --git a/tools/chapel-py/src/python-types.h b/tools/chapel-py/src/python-types.h index 78b3d9081e2c..ef1cb2dbbcec 100644 --- a/tools/chapel-py/src/python-types.h +++ b/tools/chapel-py/src/python-types.h @@ -92,7 +92,7 @@ template std::vector unwrapVector(ContextObject* CONTEXT, PyObject* vec) { std::vector toReturn(PyList_Size(vec)); for (ssize_t i = 0; i < PyList_Size(vec); i++) { - toReturn.push_back(PythonReturnTypeInfo::unwrap(CONTEXT, PyList_GetItem(vec, i))); + toReturn[i] = PythonReturnTypeInfo::unwrap(CONTEXT, PyList_GetItem(vec, i)); } return toReturn; } diff --git a/tools/chpl-language-server/src/chpl-language-server.py b/tools/chpl-language-server/src/chpl-language-server.py index f96260da9a82..b526143459da 100755 --- a/tools/chpl-language-server/src/chpl-language-server.py +++ b/tools/chpl-language-server/src/chpl-language-server.py @@ -499,6 +499,8 @@ def __init__(self, file: str, config: Optional["WorkspaceConfig"]): self.context: chapel.Context = chapel.Context() self.file_infos: List["FileInfo"] = [] self.global_uses: Dict[str, List[References]] = defaultdict(list) + self.instantiation_ids: Dict[chapel.TypedSignature, str] = {} + self.instantiation_id_counter = 0 if config: file_config = config.for_file(file) @@ -508,6 +510,27 @@ def __init__(self, file: str, config: Optional["WorkspaceConfig"]): self.context.set_module_paths(self.module_paths, self.file_paths) + def register_signature(self, sig: chapel.TypedSignature) -> str: + """ + The language server can't send over typed signatures directly for + situations such as call hierarchy items (but we need to reason about + instantiations). Instead, keep a global unique ID for each signature, + and use that to identify them. + """ + if sig in self.instantiation_ids: + return self.instantiation_ids[sig] + + self.instantiation_id_counter += 1 + uid = str(self.instantiation_id_counter) + self.instantiation_ids[sig] = uid + return uid + + def retrieve_signature(self, uid: str) -> Optional[chapel.TypedSignature]: + for sig, sig_uid in self.instantiation_ids.items(): + if sig_uid == uid: + return sig + return None + def new_file_info( self, uri: str, use_resolver: bool ) -> Tuple["FileInfo", List[Any]]: @@ -1007,7 +1030,9 @@ def __init__(self, config: CLSConfig): super().__init__("chpl-language-server", "v0.1") self.contexts: Dict[str, ContextContainer] = {} - self.file_infos: Dict[str, FileInfo] = {} + self.context_ids: Dict[ContextContainer, str] = {} + self.context_id_counter = 0 + self.file_infos: Dict[Tuple[str, Optional[str]], FileInfo] = {} self.configurations: Dict[str, WorkspaceConfig] = {} self.use_resolver: bool = config.get("resolver") @@ -1107,9 +1132,17 @@ def get_context(self, uri: str) -> ContextContainer: for file in context.file_paths: self.contexts[file] = context self.contexts[path] = context + self.context_id_counter += 1 + self.context_ids[context] = str(self.context_id_counter) return context + def retrieve_context(self, context_id: str) -> Optional[ContextContainer]: + for ctx, cid in self.context_ids.items(): + if cid == context_id: + return ctx + return None + def eagerly_process_all_files(self, context: ContextContainer): cfg = context.config if cfg: @@ -1117,7 +1150,10 @@ def eagerly_process_all_files(self, context: ContextContainer): self.get_file_info("file://" + file, do_update=False) def get_file_info( - self, uri: str, do_update: bool = False + self, + uri: str, + do_update: bool = False, + context_id: Optional[str] = None, ) -> Tuple[FileInfo, List[Any]]: """ The language server maintains a FileInfo object per file. The FileInfo @@ -1128,19 +1164,34 @@ def get_file_info( creating one if it doesn't exist. If do_update is set to True, then the FileInfo's index is rebuilt even if it has already been computed. This is useful if the underlying file has changed. + + Most of the time, we will create a new context for a given URI. When + requested, however, context_id will be used to create a FileInfo + for a specific context. This is useful if e.g., file A wants to display + an instantiation in file B. """ errors = [] - if uri in self.file_infos: - file_info = self.file_infos[uri] + fi_key = (uri, context_id) + if fi_key in self.file_infos: + file_info = self.file_infos[fi_key] if do_update: errors = file_info.context.advance() else: - file_info, errors = self.get_context(uri).new_file_info( - uri, self.use_resolver - ) - self.file_infos[uri] = file_info + if context_id: + context = self.retrieve_context(context_id) + assert context + else: + context = self.get_context(uri) + + file_info, errors = context.new_file_info(uri, self.use_resolver) + self.file_infos[fi_key] = file_info + + # Also make this the "default" context for this file in case we + # open it. + if (uri, None) not in self.file_infos: + self.file_infos[(uri, None)] = file_info # filter out errors that are not related to the file cur_path = uri[len("file://") :] @@ -1396,7 +1447,8 @@ def sym_to_call_hierarchy_item( """ loc = location_to_location(sym.location()) - inst_idx = -1 + inst_id = None + context_id = None return CallHierarchyItem( name=sym.name(), @@ -1405,11 +1457,11 @@ def sym_to_call_hierarchy_item( uri=loc.uri, range=loc.range, selection_range=location_to_range(sym.name_location()), - data=[sym.unique_id(), inst_idx], + data=[sym.unique_id(), inst_id, context_id], ) def fn_to_call_hierarchy_item( - self, sig: chapel.TypedSignature + self, sig: chapel.TypedSignature, caller_context: ContextContainer ) -> CallHierarchyItem: """ Like sym_to_call_hierarchy_item, but for function instantiations. @@ -1419,8 +1471,8 @@ def fn_to_call_hierarchy_item( """ fn: chapel.Function = sig.ast() item = self.sym_to_call_hierarchy_item(fn) - fi, _ = self.get_file_info(item.uri) - item.data[1] = fi.index_of_instantiation(fn, sig) + item.data[1] = caller_context.register_signature(sig) + item.data[2] = self.context_ids[caller_context] return item @@ -1433,16 +1485,17 @@ def unpack_call_hierarchy_item( item.data is None or not isinstance(item.data, list) or not isinstance(item.data[0], str) - or not isinstance(item.data[1], int) + or not isinstance(item.data[1], Optional[str]) + or not isinstance(item.data[2], Optional[str]) ): self.show_message( "Call hierarchy item contains missing or invalid additional data", MessageType.Error, ) return None - uid, idx = item.data + uid, inst_id, ctx = item.data - fi, _ = self.get_file_info(item.uri) + fi, _ = self.get_file_info(item.uri, context_id=ctx) # TODO: Performance: # Once the Python bindings supports it, we can use the @@ -1456,11 +1509,7 @@ def unpack_call_hierarchy_item( # We don't handle that here. return None - instantiation = None - if idx != -1: - instantiation = fi.instantiation_at_index(fn, idx) - else: - instantiation = fi.concrete_instantiation_for(fn) + instantiation = fi.context.retrieve_signature(inst_id) return (fi, fn, instantiation) @@ -2000,7 +2049,10 @@ async def prepare_call_hierarchy( # Oddly, returning multiple here makes for no child nodes in the VSCode # UI. Just take one signature for now. - return next(([ls.fn_to_call_hierarchy_item(sig)] for sig in sigs), []) + return next( + ([ls.fn_to_call_hierarchy_item(sig, fi.context)] for sig in sigs), + [], + ) @server.feature(CALL_HIERARCHY_INCOMING_CALLS) async def call_hierarchy_incoming( @@ -2046,7 +2098,7 @@ async def call_hierarchy_incoming( if isinstance(called_fn, str): item = ls.sym_to_call_hierarchy_item(hack_id_to_node[called_fn]) else: - item = ls.fn_to_call_hierarchy_item(called_fn) + item = ls.fn_to_call_hierarchy_item(called_fn, fi.context) to_return.append( CallHierarchyIncomingCall( @@ -2070,7 +2122,7 @@ async def call_hierarchy_outgoing( if unpacked is None: return None - _, fn, instantiation = unpacked + fi, fn, instantiation = unpacked outgoing_calls: Dict[chapel.TypedSignature, List[chapel.FnCall]] = ( defaultdict(list) @@ -2093,7 +2145,7 @@ async def call_hierarchy_outgoing( to_return = [] for called_fn, calls in outgoing_calls.items(): - item = ls.fn_to_call_hierarchy_item(called_fn) + item = ls.fn_to_call_hierarchy_item(called_fn, fi.context) to_return.append( CallHierarchyOutgoingCall( item, diff --git a/tools/chpl-language-server/test/basic.py b/tools/chpl-language-server/test/basic.py index 926168994618..5999bf35911f 100644 --- a/tools/chpl-language-server/test/basic.py +++ b/tools/chpl-language-server/test/basic.py @@ -140,6 +140,40 @@ async def test_go_to_definition_use_standard(client: LanguageClient): await check_goto_decl_def_module(client, doc, pos((2, 8)), mod_Time) +@pytest.mark.asyncio +async def test_go_to_definition_use_across_modules(client: LanguageClient): + """ + Ensure that go-to-definition works on symbols that reference other modules + """ + + fileA = """ + module A { + var x = 42; + } + """ + fileB = """ + module B { + use A; + var y = x; + } + """ + + async def check(docs): + docA = docs("A") + docB = docs("B") + + await check_goto_decl_def_module(client, docB, pos((1, 6)), docA) + await check_goto_decl_def( + client, docB, pos((2, 10)), (docA, pos((1, 6))) + ) + + async with source_files(client, A=fileA, B=fileB) as docs: + await check(docs) + + async with unrelated_source_files(client, A=fileA, B=fileB) as docs: + await check(docs) + + @pytest.mark.asyncio async def test_go_to_definition_standard_rename(client: LanguageClient): """ diff --git a/tools/chpl-language-server/test/call_hierarchy.py b/tools/chpl-language-server/test/call_hierarchy.py new file mode 100644 index 000000000000..0260e705e354 --- /dev/null +++ b/tools/chpl-language-server/test/call_hierarchy.py @@ -0,0 +1,214 @@ +""" +Test the call hierarchy feature, which computes calls between functions. +""" + +import sys + +from lsprotocol.types import ClientCapabilities +from lsprotocol.types import ( + CallHierarchyPrepareParams, + CallHierarchyOutgoingCallsParams, + CallHierarchyItem, +) +from lsprotocol.types import InitializeParams +import pytest +import pytest_lsp +import typing +from pytest_lsp import ClientServerConfig, LanguageClient + +from util.utils import * +from util.config import CLS_PATH + + +@pytest_lsp.fixture( + config=ClientServerConfig( + server_command=[ + sys.executable, + CLS_PATH(), + "--resolver", + ], + client_factory=get_base_client, + ) +) +async def client(lsp_client: LanguageClient): + # Setup + params = InitializeParams(capabilities=ClientCapabilities()) + await lsp_client.initialize_session(params) + + yield + + # Teardown + await lsp_client.shutdown_session() + + +class CallTree: + def __init__(self, item_id: str, children: typing.List["CallTree"]): + self.item_id = item_id + self.children = children + + +async def collect_call_tree( + client: LanguageClient, item: CallHierarchyItem, depth: int +) -> typing.Optional[CallTree]: + if depth <= 0: + return None + + assert isinstance(item.data, list) + assert len(item.data) == 3 + item_id = item.data[0] + + children = [] + outgoing = await client.call_hierarchy_outgoing_calls_async( + CallHierarchyOutgoingCallsParams(item) + ) + if outgoing is not None: + for outgoing_call in outgoing: + new_tree = await collect_call_tree( + client, outgoing_call.to, depth - 1 + ) + if new_tree is not None: + children.append(new_tree) + + return CallTree(item_id, children) + + +async def compute_call_hierarchy( + client: LanguageClient, + doc: TextDocumentIdentifier, + position: Position, + depth: int, +) -> typing.Optional[CallTree]: + items = await client.text_document_prepare_call_hierarchy_async( + CallHierarchyPrepareParams(text_document=doc, position=position) + ) + if items is None: + return None + + assert len(items) == 1 + return await collect_call_tree(client, items[0], depth) + + +def verify_call_hierarchy(tree: CallTree, expected: CallTree): + assert tree.item_id == expected.item_id + assert len(tree.children) == len(expected.children) + for i in range(len(tree.children)): + verify_call_hierarchy(tree.children[i], expected.children[i]) + + +async def check_call_hierarchy( + client: LanguageClient, + doc: TextDocumentIdentifier, + position: Position, + expected: CallTree, + depth: int = 10, +) -> typing.Optional[CallTree]: + items = await client.text_document_prepare_call_hierarchy_async( + CallHierarchyPrepareParams(text_document=doc, position=position) + ) + assert items is not None + assert len(items) == 1 + tree = await collect_call_tree(client, items[0], depth) + assert tree is not None + verify_call_hierarchy(tree, expected) + return tree + + +@pytest.mark.asyncio +async def test_call_hierarchy_basic(client: LanguageClient): + file = """ + proc foo() {} + proc bar() do foo(); + bar(); + """ + + async with source_file(client, file) as doc: + expect = CallTree("main.bar", [CallTree("main.foo", [])]) + await check_call_hierarchy(client, doc, pos((2, 0)), expect) + + +@pytest.mark.asyncio +async def test_call_hierarchy_overloads(client: LanguageClient): + file = """ + proc foo(arg: int) {} + proc foo(arg: bool) {} + foo(1); + foo(true); + """ + + async with source_file(client, file) as doc: + expect_int = CallTree("main.foo", []) + await check_call_hierarchy(client, doc, pos((2, 0)), expect_int) + expect_bool = CallTree("main.foo#1", []) + await check_call_hierarchy(client, doc, pos((3, 0)), expect_bool) + + +@pytest.mark.asyncio +async def test_call_hierarchy_recursive(client: LanguageClient): + file = """ + proc foo() do foo(); + foo(); + """ + + async with source_file(client, file) as doc: + expect = CallTree("main.foo", [CallTree("main.foo", [])]) + await check_call_hierarchy(client, doc, pos((1, 0)), expect, depth=2) + + +@pytest.mark.asyncio +async def test_call_hierarchy_across_files(client: LanguageClient): + fileA = """ + module A { + proc someImplementationDetail(arg: string) {} + } + """ + fileB = """ + module B { + use A; + + proc toString(x: int): string do return ""; + proc toString(x: real): string do return ""; + + proc doSomething(arg) { + someImplementationDetail(toString(arg)); + } + } + """ + fileC = """ + module C { + use B; + + doSomething(12); + doSomething(12.0); + } + """ + + expected_int = CallTree( + "B.doSomething", + [ + CallTree("A.someImplementationDetail", []), + CallTree("B.toString", []), + ], + ) + expected_real = CallTree( + "B.doSomething", + [ + CallTree("A.someImplementationDetail", []), + CallTree("B.toString#1", []), + ], + ) + + async def check(docs): + await check_call_hierarchy(client, docs("C"), pos((3, 2)), expected_int) + await check_call_hierarchy( + client, docs("C"), pos((4, 2)), expected_real + ) + + # Ensure that call hierarchy works without .cls-commands.json... + async with unrelated_source_files( + client, A=fileA, B=fileB, C=fileC + ) as docs: + await check(docs) + + # ...and with .cls-commands.json + async with source_files(client, A=fileA, B=fileB, C=fileC) as docs: + await check(docs) diff --git a/tools/chpl-language-server/test/util/utils.py b/tools/chpl-language-server/test/util/utils.py index dcfd48cf1129..4717f283ebde 100644 --- a/tools/chpl-language-server/test/util/utils.py +++ b/tools/chpl-language-server/test/util/utils.py @@ -73,7 +73,12 @@ def on_semantic_token_refresh(params): class SourceFilesContext: - def __init__(self, client: LanguageClient, files: typing.Dict[str, str]): + def __init__( + self, + client: LanguageClient, + files: typing.Dict[str, str], + build_cls_commands: bool = True, + ): self.tempdir = tempfile.TemporaryDirectory() self.client = client @@ -96,8 +101,9 @@ def __init__(self, client: LanguageClient, files: typing.Dict[str, str]): commands[filepath] = [{"module_dirs": [], "files": allfiles}] commandspath = os.path.join(self.tempdir.name, ".cls-commands.json") - with open(commandspath, "w") as f: - json.dump(commands, f) + if build_cls_commands: + with open(commandspath, "w") as f: + json.dump(commands, f) def _get_doc(self, name: str) -> TextDocumentIdentifier: return TextDocumentIdentifier( @@ -159,6 +165,14 @@ def source_files(client: LanguageClient, **files: str): return SourceFilesContext(client, files) +def unrelated_source_files(client: LanguageClient, **files: str): + """ + Same as 'source_files', but doesn't create a .cls-commands.json file that + would cause the files to be treated as "connected" and resolved together. + """ + return SourceFilesContext(client, files, build_cls_commands=False) + + def source_file( client: LanguageClient, contents: str,