From e6908a66ecabaca6e7adf4e847afe228f5d9f257 Mon Sep 17 00:00:00 2001 From: Patrick Walton Date: Thu, 16 Dec 2010 11:11:48 -0800 Subject: [PATCH] rustc: Infer the types of type-parametric functions --- src/comp/middle/typeck.rs | 76 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/src/comp/middle/typeck.rs b/src/comp/middle/typeck.rs index 9837ffa1add..332c893e841 100644 --- a/src/comp/middle/typeck.rs +++ b/src/comp/middle/typeck.rs @@ -259,6 +259,79 @@ fn ty_to_str(&@ty typ) -> str { ret s; } +// Replaces parameter types inside a type with type variables. +fn generalize_ty(@crate_ctxt cx, @ty t) -> @ty { + fn rewrap(@ty orig, &sty new) -> @ty { + ret @rec(struct=new, mut=orig.mut, cname=orig.cname); + } + + fn recur(@crate_ctxt cx, @ty t, + &hashmap[ast.def_id,@ty] ty_params_to_ty_vars) -> @ty { + alt (t.struct) { + case (ty_box(?subty)) { + auto new_subty = recur(cx, subty, ty_params_to_ty_vars); + ret rewrap(t, ty_box(new_subty)); + } + case (ty_vec(?subty)) { + auto new_subty = recur(cx, subty, ty_params_to_ty_vars); + ret rewrap(t, ty_vec(new_subty)); + } + case (ty_tup(?subtys)) { + let vec[@ty] new_subtys = vec(); + for (@ty subty in subtys) { + new_subtys += vec(recur(cx, subty, ty_params_to_ty_vars)); + } + ret rewrap(t, ty_tup(new_subtys)); + } + case (ty_rec(?fields)) { + let vec[field] new_fields = vec(); + for (field fld in fields) { + auto new_ty = recur(cx, fld.ty, ty_params_to_ty_vars); + new_fields += vec(rec(ident=fld.ident, ty=new_ty)); + } + ret rewrap(t, ty_rec(new_fields)); + } + case (ty_fn(?args, ?ret_ty)) { + let vec[arg] new_args = vec(); + for (arg a in args) { + auto new_ty = recur(cx, a.ty, ty_params_to_ty_vars); + new_args += vec(rec(mode=a.mode, ty=new_ty)); + } + auto new_ret_ty = recur(cx, ret_ty, ty_params_to_ty_vars); + ret rewrap(t, ty_fn(new_args, new_ret_ty)); + } + case (ty_obj(?methods)) { + let vec[method] new_methods = vec(); + for (method m in methods) { + let vec[arg] new_args = vec(); + for (arg a in m.inputs) { + auto new_ty = recur(cx, a.ty, ty_params_to_ty_vars); + new_args += vec(rec(mode=a.mode, ty=new_ty)); + } + auto new_rty = recur(cx, m.output, ty_params_to_ty_vars); + new_methods += vec(rec(ident=m.ident, inputs=new_args, + output=new_rty)); + } + ret rewrap(t, ty_obj(new_methods)); + } + case (ty_param(?pid)) { + if (ty_params_to_ty_vars.contains_key(pid)) { + ret ty_params_to_ty_vars.get(pid); + } + auto var_ty = next_ty_var(cx); + ty_params_to_ty_vars.insert(pid, var_ty); + ret var_ty; + } + case (_) { /* fall through */ } + } + + ret t; + } + + auto ty_params_to_ty_vars = common.new_def_hash[@ty](); + ret recur(cx, t, ty_params_to_ty_vars); +} + // Parses the programmer's textual representation of a type into our internal // notion of a type. `getter` is a function that returns the type // corresponding to a definition ID. @@ -1720,6 +1793,9 @@ fn check_expr(&fn_ctxt fcx, @ast.expr expr) -> @ast.expr { fail; } } + + t = generalize_ty(fcx.ccx, t); + ret @fold.respan[ast.expr_](expr.span, ast.expr_name(name, defopt, ast.ann_type(t)));