rt: Maintain stack ptrs correctly when returning from stack switches

This commit is contained in:
Brian Anderson 2012-02-13 23:18:21 -08:00
parent 214cdd0dee
commit a393fb3221
3 changed files with 64 additions and 0 deletions

@ -231,6 +231,7 @@ rust_task::call_on_c_stack(void *args, void *fn_ptr) {
// Too expensive to check
// I(thread, on_rust_stack());
uintptr_t prev_rust_sp = next_rust_sp;
next_rust_sp = get_sp();
bool borrowed_a_c_stack = false;
@ -251,6 +252,8 @@ rust_task::call_on_c_stack(void *args, void *fn_ptr) {
if (borrowed_a_c_stack) {
return_c_stack();
}
next_rust_sp = prev_rust_sp;
}
inline void
@ -259,11 +262,14 @@ rust_task::call_on_rust_stack(void *args, void *fn_ptr) {
// I(thread, !on_rust_stack());
I(thread, next_rust_sp);
uintptr_t prev_c_sp = next_c_sp;
next_c_sp = get_sp();
uintptr_t sp = sanitize_next_sp(next_rust_sp);
__morestack(args, fn_ptr, sp);
next_c_sp = prev_c_sp;
}
inline void

@ -0,0 +1,27 @@
native mod rustrt {
fn rust_dbg_call(cb: *u8,
data: ctypes::uintptr_t) -> ctypes::uintptr_t;
}
crust fn cb(data: ctypes::uintptr_t) -> ctypes::uintptr_t {
if data == 1u {
data
} else {
count(data - 1u) + 1u
}
}
fn count(n: uint) -> uint {
#debug("n = %?", n);
rustrt::rust_dbg_call(cb, n)
}
fn main() {
// Make sure we're on a task with small Rust stacks (main currently
// has a large stack)
task::spawn {||
let result = count(1000u);
#debug("result = %?", result);
assert result == 1000u;
};
}

@ -0,0 +1,31 @@
// This time we're testing repeatedly going up and down both stacks to
// make sure the stack pointers are maintained properly in both
// directions
native mod rustrt {
fn rust_dbg_call(cb: *u8,
data: ctypes::uintptr_t) -> ctypes::uintptr_t;
}
crust fn cb(data: ctypes::uintptr_t) -> ctypes::uintptr_t {
if data == 1u {
data
} else {
count(data - 1u) + count(data - 1u)
}
}
fn count(n: uint) -> uint {
#debug("n = %?", n);
rustrt::rust_dbg_call(cb, n)
}
fn main() {
// Make sure we're on a task with small Rust stacks (main currently
// has a large stack)
task::spawn {||
let result = count(12u);
#debug("result = %?", result);
assert result == 2048u;
};
}