Skip to content

Commit

Permalink
feat: infer associate functions (#575)
Browse files Browse the repository at this point in the history
  • Loading branch information
baszalmstra authored Nov 30, 2024
1 parent 6537d40 commit 0acea1c
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 14 deletions.
2 changes: 1 addition & 1 deletion crates/mun_hir/src/method_resolution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ impl InherentImpls {
self.map.values().flatten().copied()
}

// Returns all implementations defined for the specified type.
/// Returns all implementations defined for the specified type.
pub fn for_self_ty(&self, self_ty: &Ty) -> &[ImplId] {
match self_ty.interned() {
TyKind::Struct(s) => self.map.get(&s.id).map_or(&[], AsRef::as_ref),
Expand Down
84 changes: 71 additions & 13 deletions crates/mun_hir/src/ty/infer.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{ops::Index, sync::Arc};

use la_arena::ArenaMap;
use mun_hir_input::ModuleId;
use rustc_hash::FxHashSet;

use crate::{
Expand All @@ -24,8 +25,10 @@ mod unify;

use crate::{
expr::{LiteralFloat, LiteralFloatKind, LiteralInt, LiteralIntKind},
has_module::HasModule,
ids::DefWithBodyId,
resolve::{resolver_for_expr, HasResolver},
method_resolution::lookup_method,
resolve::{resolver_for_expr, HasResolver, ResolveValueResult},
ty::{
primitives::{FloatTy, IntTy},
TyKind,
Expand Down Expand Up @@ -184,6 +187,13 @@ impl<'a> InferenceResultBuilder<'a> {
}
}

/// Returns the module in which the body is defined.
pub fn module(&self) -> ModuleId {
match self.body.owner() {
DefWithBodyId::FunctionId(func) => func.module(self.db.upcast()),
}
}

/// Associate the given `ExprId` with the specified `Ty`.
fn set_expr_type(&mut self, expr: ExprId, ty: Ty) {
self.type_of_expr.insert(expr, ty);
Expand Down Expand Up @@ -718,25 +728,73 @@ impl<'a> InferenceResultBuilder<'a> {
}
}

fn resolve_assoc_item(
&mut self,
def: TypeNs,
path: &Path,
remaining_index: usize,
id: ExprId,
) -> Option<ValueNs> {
// We can only resolve the last element of the path.
let name = if remaining_index == path.segments.len() - 1 {
&path.segments[remaining_index]
} else {
return None;
};

// Infer the type of the definitions
let type_for_def_fn = |def| self.db.type_for_def(def, Namespace::Types);
let root_ty = match def {
TypeNs::SelfType(id) => self.db.type_for_impl_self(id),
TypeNs::StructId(id) => type_for_def_fn(TypableDef::Struct(id.into())),
TypeNs::TypeAliasId(id) => type_for_def_fn(TypableDef::TypeAlias(id.into())),
TypeNs::PrimitiveType(id) => type_for_def_fn(TypableDef::PrimitiveType(id)),
};

// Resolve the value.
let function_id = match lookup_method(self.db, &root_ty, self.module(), name) {
Ok(value) => value,
Err(Some(value)) => {
self.diagnostics
.push(InferenceDiagnostic::PathIsPrivate { id });
value
}
_ => return None,
};

Some(ValueNs::FunctionId(function_id))
}

fn resolve_value_path_inner(
&mut self,
resolver: &Resolver,
path: &Path,
id: ExprId,
) -> Option<ValueNs> {
let value_or_partial = resolver.resolve_path_as_value(self.db.upcast(), path)?;
match value_or_partial {
ResolveValueResult::ValueNs(it, vis) => {
if !vis.is_visible_from(self.db, self.module()) {
self.diagnostics
.push(diagnostics::InferenceDiagnostic::PathIsPrivate { id });
}

Some(it)
}
ResolveValueResult::Partial(def, remaining_index) => {
self.resolve_assoc_item(def, path, remaining_index, id)
}
}
}

fn infer_path_expr(
&mut self,
resolver: &Resolver,
path: &Path,
id: ExprId,
check_params: &CheckParams,
) -> Option<Ty> {
if let Some((value, vis)) = resolver.resolve_path_as_value_fully(self.db.upcast(), path) {
// Check visibility of this item
if !vis.is_visible_from(
self.db,
self.resolver
.module()
.expect("resolver must have a module to be able to resolve modules"),
) {
self.diagnostics
.push(diagnostics::InferenceDiagnostic::PathIsPrivate { id });
}

if let Some(value) = self.resolve_value_path_inner(resolver, path, id) {
// Match based on what type of value we found
match value {
ValueNs::ImplSelf(i) => {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
source: crates/mun_hir/src/ty/tests.rs
expression: "infer(r#\"\n //- /foo.mun\n pub struct Foo {\n a: i32\n }\n\n impl Foo {\n fn new(){}\n }\n\n //- /mod.mun\n fn main() {\n foo::Foo::new();\n }\n \"#)"
---
16..29: access of private type
10..34 '{ ...w(); }': ()
16..29 'foo::Foo::new': function new() -> ()
16..31 'foo::Foo::new()': ()
54..56 '{}': ()
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
---
source: crates/mun_hir/src/ty/tests.rs
expression: "infer(r#\"\n struct Foo {\n a: i32\n }\n\n impl Foo {\n fn new() -> Self {\n Self { a: 3 }\n }\n }\n\n fn main() {\n let a = Foo::new();\n }\n \"#)"
---
102..129 '{ ...w(); }': ()
112..113 'a': Foo
116..124 'Foo::new': function new() -> Foo
116..126 'Foo::new()': Foo
59..88 '{ ... }': Foo
69..82 'Self { a: 3 }': Foo
79..80 '3': i32
42 changes: 42 additions & 0 deletions crates/mun_hir/src/ty/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,48 @@ fn infer_self_field() {
));
}

#[test]
fn infer_assoc_function() {
insta::assert_snapshot!(infer(
r#"
struct Foo {
a: i32
}
impl Foo {
fn new() -> Self {
Self { a: 3 }
}
}
fn main() {
let a = Foo::new();
}
"#
));
}

#[test]
fn infer_access_hidden_assoc_function() {
insta::assert_snapshot!(infer(
r#"
//- /foo.mun
pub struct Foo {
a: i32
}
impl Foo {
fn new(){}
}
//- /mod.mun
fn main() {
foo::Foo::new();
}
"#
));
}

#[test]
fn infer_basics() {
insta::assert_snapshot!(infer(
Expand Down

0 comments on commit 0acea1c

Please sign in to comment.