Add simple syntax extension (#simplext)

This commit is contained in:
Paul Stansifer 2011-06-20 17:26:17 -07:00 committed by Graydon Hoare
parent b632681780
commit c3901cdf8e
15 changed files with 290 additions and 60 deletions

View File

@ -1,4 +1,4 @@
import std::vec;
import std::option;
import std::map::hashmap;
import driver::session::session;
@ -6,20 +6,24 @@ import front::parser::parser;
import util::common::span;
import util::common::new_str_hash;
type syntax_expander =
fn(&ext_ctxt, span, &vec[@ast::expr], option::t[str]) -> @ast::expr ;
// Temporary: to introduce a tag in order to make a recursive type work
tag syntax_extension { x(syntax_expander); }
type syntax_expander =
fn(&ext_ctxt, span, &vec[@ast::expr], option::t[str]) -> @ast::expr;
type macro_definer = fn(&ext_ctxt, span, &vec[@ast::expr],
option::t[str]) -> tup(str, syntax_extension);
tag syntax_extension {
normal(syntax_expander);
macro_defining(macro_definer);
}
// A temporary hard-coded map of methods for expanding syntax extension
// AST nodes into full ASTs
fn syntax_expander_table() -> hashmap[str, syntax_extension] {
auto syntax_expanders = new_str_hash[syntax_extension]();
syntax_expanders.insert("fmt", x(extfmt::expand_syntax_ext));
syntax_expanders.insert("env", x(extenv::expand_syntax_ext));
syntax_expanders.insert("fmt", normal(extfmt::expand_syntax_ext));
syntax_expanders.insert("env", normal(extenv::expand_syntax_ext));
syntax_expanders.insert("simplext",
macro_defining(extsimplext::add_new_extension));
ret syntax_expanders;
}
@ -51,6 +55,37 @@ fn mk_ctxt(parser parser) -> ext_ctxt {
span_unimpl=ext_span_unimpl,
next_id=ext_next_id);
}
fn expr_to_str(&ext_ctxt cx, @ast::expr expr, str error) -> str {
alt (expr.node) {
case (ast::expr_lit(?l)) {
alt (l.node) {
case (ast::lit_str(?s, _)) { ret s; }
case (_) { cx.span_fatal(l.span, error); }
}
}
case (_) { cx.span_fatal(expr.span, error); }
}
}
fn expr_to_ident(&ext_ctxt cx, @ast::expr expr, str error) -> ast::ident {
alt(expr.node) {
case (ast::expr_path(?p)) {
if (vec::len(p.node.types) > 0u
|| vec::len(p.node.idents) != 1u) {
cx.span_fatal(expr.span, error);
} else {
ret p.node.idents.(0);
}
}
case (_) {
cx.span_fatal(expr.span, error);
}
}
}
//
// Local Variables:
// mode: rust

View File

@ -21,27 +21,13 @@ fn expand_syntax_ext(&ext_ctxt cx, common::span sp, &vec[@ast::expr] args,
// FIXME: if this was more thorough it would manufacture an
// option::t[str] rather than just an maybe-empty string.
auto var = expr_to_str(cx, args.(0));
auto var = expr_to_str(cx, args.(0), "#env requires a string");
alt (generic_os::getenv(var)) {
case (option::none) { ret make_new_str(cx, sp, ""); }
case (option::some(?s)) { ret make_new_str(cx, sp, s); }
}
}
// FIXME: duplicate code copied from extfmt:
fn expr_to_str(&ext_ctxt cx, @ast::expr expr) -> str {
alt (expr.node) {
case (ast::expr_lit(?l)) {
alt (l.node) {
case (ast::lit_str(?s, _)) { ret s; }
case (_) { cx.span_fatal(l.span, "malformed #env call"); }
}
}
case (_) { cx.span_fatal(expr.span, "malformed #env call"); }
}
}
fn make_new_lit(&ext_ctxt cx, common::span sp, ast::lit_ lit) -> @ast::expr {
auto sp_lit = @rec(node=lit, span=sp);
ret @rec(id=cx.next_id(), node=ast::expr_lit(sp_lit), span=sp);

View File

@ -20,7 +20,8 @@ fn expand_syntax_ext(&ext_ctxt cx, common::span sp, &vec[@ast::expr] args,
if (vec::len[@ast::expr](args) == 0u) {
cx.span_fatal(sp, "#fmt requires a format string");
}
auto fmt = expr_to_str(cx, args.(0));
auto fmt = expr_to_str(cx, args.(0), "first argument to #fmt must be a "
+ "string literal.");
auto fmtspan = args.(0).span;
log "Format string:";
log fmt;
@ -32,20 +33,6 @@ fn expand_syntax_ext(&ext_ctxt cx, common::span sp, &vec[@ast::expr] args,
ret pieces_to_expr(cx, sp, pieces, args);
}
fn expr_to_str(&ext_ctxt cx, @ast::expr expr) -> str {
auto err_msg = "first argument to #fmt must be a string literal";
alt (expr.node) {
case (ast::expr_lit(?l)) {
alt (l.node) {
case (ast::lit_str(?s, _)) { ret s; }
case (_) { cx.span_fatal(l.span, err_msg); }
}
}
case (_) { cx.span_fatal(expr.span, err_msg); }
}
}
// FIXME: A lot of these functions for producing expressions can probably
// be factored out in common with other code that builds expressions.
// FIXME: Cleanup the naming of these functions

View File

@ -0,0 +1,144 @@
use std;
import util::common::span;
import std::vec;
import std::option;
import vec::map;
import vec::len;
import option::some;
import option::none;
import ext::syntax_extension;
import ext::ext_ctxt;
import ext::normal;
import ext::expr_to_str;
import ext::expr_to_ident;
import fold::*;
import ast::ident;
import ast::path_;
import ast::expr_path;
export add_new_extension;
//temporary, until 'position' shows up in the snapshot
fn position[T](&T x, &vec[T] v) -> option::t[uint] {
let uint i = 0u;
while (i < len(v)) {
if (x == v.(i)) { ret some[uint](i); }
i += 1u;
}
ret none[uint];
}
// substitute, in a position that's required to be an ident
fn subst_ident(&ext_ctxt cx, &vec[@ast::expr] args,
@vec[ident] param_names, &ident i, ast_fold fld) -> ident {
alt (position(i, *param_names)) {
case (some[uint](?idx)) {
ret expr_to_ident(cx, args.(idx),
"This argument is expanded as an "
+ "identifier; it must be one.");
}
case (none[uint]) {
ret i;
}
}
}
fn subst_path(&ext_ctxt cx, &vec[@ast::expr] args,
@vec[ident] param_names, &path_ p, ast_fold fld) -> path_ {
// Don't substitute into qualified names.
if (len(p.types) > 0u || len(p.idents) != 1u) { ret p; }
alt (position(p.idents.(0), *param_names)) {
case (some[uint](?idx)) {
alt (args.(idx).node) {
case (expr_path(?new_path)) {
ret new_path.node;
}
case (_) {
cx.span_fatal(args.(idx).span,
"This argument is expanded as a path; "
+ "it must be one.");
}
}
}
case (none[uint]) { ret p; }
}
}
fn subst_expr(&ext_ctxt cx, &vec[@ast::expr] args, @vec[ident] param_names,
&ast::expr_ e, ast_fold fld,
fn(&ast::expr_, ast_fold) -> ast::expr_ orig) -> ast::expr_ {
ret alt(e) {
case (expr_path(?p)){
// Don't substitute into qualified names.
if (len(p.node.types) > 0u || len(p.node.idents) != 1u) { e }
alt (position(p.node.idents.(0), *param_names)) {
case (some[uint](?idx)) {
args.(idx).node
}
case (none[uint]) { e }
}
}
case (_) { orig(e,fld) }
}
}
fn add_new_extension(&ext_ctxt cx, span sp, &vec[@ast::expr] args,
option::t[str] body) -> tup(str, syntax_extension) {
if (len(args) < 2u) {
cx.span_fatal(sp, "malformed extension description");
}
fn generic_extension(&ext_ctxt cx, span sp, &vec[@ast::expr] args,
option::t[str] body, @vec[ident] param_names,
@ast::expr dest_form) -> @ast::expr {
if (len(args) != len(*param_names)) {
cx.span_fatal(sp, #fmt("extension expects %u arguments, got %u",
len(*param_names), len(args)));
}
auto afp = default_ast_fold();
auto f_pre =
rec(fold_ident = bind subst_ident(cx, args, param_names, _, _),
fold_path = bind subst_path(cx, args, param_names, _, _),
fold_expr = bind subst_expr(cx, args, param_names, _, _,
afp.fold_expr)
with *afp);
auto f = make_fold(f_pre);
auto result = f.fold_expr(dest_form);
dummy_out(f); //temporary: kill circular reference
ret result;
}
let vec[ident] param_names = vec::empty[ident]();
let uint idx = 1u;
while(1u+idx < len(args)) {
param_names +=
[expr_to_ident(cx, args.(idx),
"this parameter name must be an identifier.")];
idx += 1u;
}
ret tup(expr_to_str(cx, args.(0), "first arg must be a literal string."),
normal(bind generic_extension(_,_,_,_,@param_names,
args.(len(args)-1u))));
}
//
// Local Variables:
// mode: rust
// fill-column: 78;
// indent-tabs-mode: nil
// c-basic-offset: 4
// buffer-file-coding-system: utf-8-unix
// compile-command: "make -k -C $RBUILD 2>&1 | sed -e 's/\\/x\\//x:\\//g'";
// End:
//

View File

@ -963,10 +963,19 @@ fn expand_syntax_ext(&parser p, common::span sp, &ast::path path,
auto extname = path.node.idents.(0);
alt (p.get_syntax_expanders().find(extname)) {
case (none) { p.fatal("unknown syntax expander: '" + extname + "'"); }
case (some(ext::x(?ext))) {
case (some(ext::normal(?ext))) {
auto ext_cx = ext::mk_ctxt(p);
ret ast::expr_ext(path, args, body, ext(ext_cx, sp, args, body));
}
// because we have expansion inside parsing, new macros are only
// visible further down the file
case (some(ext::macro_defining(?ext))) {
auto ext_cx = ext::mk_ctxt(p);
auto name_and_extension = ext(ext_cx, sp, args, body);
p.get_syntax_expanders().insert(name_and_extension._0,
name_and_extension._1);
ret ast::expr_tup(vec::empty[ast::elt]());
}
}
}

View File

@ -597,8 +597,6 @@ fn lookup_in_scope(&env e, scopes sc, &span sp, &ident name, namespace ns) ->
option::t[def] {
fn in_scope(&env e, &span sp, &ident name, &scope s, namespace ns) ->
option::t[def] {
//not recursing through globs
alt (s) {
case (scope_crate(?c)) {
ret lookup_in_local_mod(e, -1, sp, name, ns, inside);

View File

@ -48,6 +48,7 @@ mod front {
mod ext;
mod extfmt;
mod extenv;
mod extsimplext;
mod fold;
mod codemap;
mod lexer;

View File

@ -228,6 +228,24 @@ fn find[T](fn(&T) -> bool f, &vec[T] v) -> option::t[T] {
ret none[T];
}
fn position[T](&T x, &array[T] v) -> option::t[uint] {
let uint i = 0u;
while (i < len(v)) {
if (x == v.(i)) { ret some[uint](i); }
i += 1u;
}
ret none[uint];
}
fn position_pred[T](fn (&T) -> bool f, &vec[T] v) -> option::t[uint] {
let uint i = 0u;
while (i < len(v)) {
if (f(v.(i))) { ret some[uint](i); }
i += 1u;
}
ret none[uint];
}
fn member[T](&T x, &array[T] v) -> bool {
for (T elt in v) { if (x == elt) { ret true; } }
ret false;

View File

@ -1,4 +1,4 @@
// error-pattern:malformed #env call
// error-pattern:requires a string
fn main() {
#env(10);

View File

@ -0,0 +1,6 @@
//error-pattern:expanded as an identifier
fn main() {
#simplext("mylambda", x, body, {fn f(int x) -> int {ret body}; f});
assert(#mylambda(y*1, y*2)(8) == 16);
}

View File

@ -0,0 +1,7 @@
//error-pattern:expects 0 arguments, got 16
fn main() {
#simplext("trivial", 1*2*4*2*1);
assert(#trivial(1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16) == 16);
}

View File

@ -1,9 +1,12 @@
use std;
import std::vec::*;
import std::option;
fn test_init_elt() {
let vec[uint] v = std::vec::init_elt[uint](5u, 3u);
assert (std::vec::len[uint](v) == 3u);
let vec[uint] v = init_elt[uint](5u, 3u);
assert (len[uint](v) == 3u);
assert (v.(0) == 5u);
assert (v.(1) == 5u);
assert (v.(2) == 5u);
@ -13,8 +16,8 @@ fn id(uint x) -> uint { ret x; }
fn test_init_fn() {
let fn(uint) -> uint op = id;
let vec[uint] v = std::vec::init_fn[uint](op, 5u);
assert (std::vec::len[uint](v) == 5u);
let vec[uint] v = init_fn[uint](op, 5u);
assert (len[uint](v) == 5u);
assert (v.(0) == 0u);
assert (v.(1) == 1u);
assert (v.(2) == 2u);
@ -24,17 +27,17 @@ fn test_init_fn() {
fn test_slice() {
let vec[int] v = [1, 2, 3, 4, 5];
auto v2 = std::vec::slice[int](v, 2u, 4u);
assert (std::vec::len[int](v2) == 2u);
auto v2 = slice[int](v, 2u, 4u);
assert (len[int](v2) == 2u);
assert (v2.(0) == 3);
assert (v2.(1) == 4);
}
fn test_map() {
fn square(&int x) -> int { ret x * x; }
let std::option::operator[int, int] op = square;
let option::operator[int, int] op = square;
let vec[int] v = [1, 2, 3, 4, 5];
let vec[int] s = std::vec::map[int, int](op, v);
let vec[int] s = map[int, int](op, v);
let int i = 0;
while (i < 5) { assert (v.(i) * v.(i) == s.(i)); i += 1; }
}
@ -44,16 +47,16 @@ fn test_map2() {
auto f = times;
auto v0 = [1, 2, 3, 4, 5];
auto v1 = [5, 4, 3, 2, 1];
auto u = std::vec::map2[int, int, int](f, v0, v1);
auto u = map2[int, int, int](f, v0, v1);
auto i = 0;
while (i < 5) { assert (v0.(i) * v1.(i) == u.(i)); i += 1; }
}
fn test_filter_map() {
fn halve(&int i) -> std::option::t[int] {
fn halve(&int i) -> option::t[int] {
if (i % 2 == 0) {
ret std::option::some[int](i / 2);
} else { ret std::option::none[int]; }
ret option::some[int](i / 2);
} else { ret option::none[int]; }
}
fn halve_for_sure(&int i) -> int { ret i / 2; }
let vec[int] all_even = [0, 2, 8, 6];
@ -61,11 +64,31 @@ fn test_filter_map() {
let vec[int] all_odd2 = [];
let vec[int] mix = [9, 2, 6, 7, 1, 0, 0, 3];
let vec[int] mix_dest = [1, 3, 0, 0];
assert (std::vec::filter_map(halve, all_even) ==
std::vec::map(halve_for_sure, all_even));
assert (std::vec::filter_map(halve, all_odd1) == std::vec::empty[int]());
assert (std::vec::filter_map(halve, all_odd2) == std::vec::empty[int]());
assert (std::vec::filter_map(halve, mix) == mix_dest);
assert (filter_map(halve, all_even) ==
map(halve_for_sure, all_even));
assert (filter_map(halve, all_odd1) == empty[int]());
assert (filter_map(halve, all_odd2) == empty[int]());
assert (filter_map(halve, mix) == mix_dest);
}
fn test_position() {
let vec[int] v1 = [1, 2, 3, 3, 2, 5];
assert (position(1, v1) == option::some[uint](0u));
assert (position(2, v1) == option::some[uint](1u));
assert (position(5, v1) == option::some[uint](5u));
assert (position(4, v1) == option::none[uint]);
}
fn test_position_pred() {
fn less_than_three(&int i) -> bool {
ret i <3;
}
fn is_eighteen(&int i) -> bool {
ret i == 18;
}
let vec[int] v1 = [5, 4, 3, 2, 1];
assert (position_pred(less_than_three, v1) == option::some[uint](3u));
assert (position_pred(is_eighteen, v1) == option::none[uint]);
}
fn main() {
@ -75,4 +98,6 @@ fn main() {
test_map();
test_map2();
test_filter_map();
test_position();
test_position_pred();
}

View File

@ -0,0 +1,5 @@
fn main() {
#simplext("mylambda", x, body, {fn f(int x) -> int {ret body}; f});
assert(#mylambda(y,y*2)(8) == 16);
}

View File

@ -0,0 +1,5 @@
fn main() {
#simplext("trivial", 1*2*4*2*1);
assert(#trivial() == 16);
}

View File

@ -0,0 +1,4 @@
fn main() {
#simplext("m1", a, a*4);
assert (#m1(2) == 8);
}