Final 200 lines of code
1
#![feature(naked_functions)]
2
use std::arch::asm;
3
4
/// Size in bytes of the stack buffer allocated for every thread slot (2 MiB).
const DEFAULT_STACK_SIZE: usize = 1024 * 1024 * 2;
/// Total number of thread slots created by `Runtime::new`, including the
/// base thread that occupies slot 0.
const MAX_THREADS: usize = 4;
/// Address of the active `Runtime`, written by `Runtime::init` and read back
/// by `guard` and `yield_thread`. It is `0` until `init` has been called.
static mut RUNTIME: usize = 0;
7
8
pub struct Runtime {
9
threads: Vec<Thread>,
10
current: usize,
11
}
12
13
/// Scheduling state of a thread slot.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum State {
    /// The slot is free; `Runtime::spawn` may assign a new task to it.
    Available,
    /// The slot's thread is the one currently executing.
    Running,
    /// The slot's thread has work to do and can be resumed by the scheduler.
    Ready,
}
19
20
struct Thread {
21
id: usize,
22
stack: Vec<u8>,
23
ctx: ThreadContext,
24
state: State,
25
}
26
27
/// Register snapshot saved and restored by `switch`.
///
/// The struct is `repr(C)` and the field order is significant: `switch`
/// addresses the fields by fixed byte offsets (`0x00` for `rsp` through
/// `0x30` for `rbp`), so fields must not be reordered or resized.
#[derive(Debug, Default)]
#[repr(C)]
struct ThreadContext {
    /// Stack pointer (offset `0x00`).
    rsp: u64,
    /// Offset `0x08`.
    r15: u64,
    /// Offset `0x10`.
    r14: u64,
    /// Offset `0x18`.
    r13: u64,
    /// Offset `0x20`.
    r12: u64,
    /// Offset `0x28`.
    rbx: u64,
    /// Offset `0x30`.
    rbp: u64,
}
38
39
impl Thread {
40
fn new(id: usize) -> Self {
41
Thread {
42
id,
43
stack: vec![0_u8; DEFAULT_STACK_SIZE],
44
ctx: ThreadContext::default(),
45
state: State::Available,
46
}
47
}
48
}
49
50
impl Runtime {
51
pub fn new() -> Self {
52
let base_thread = Thread {
53
id: 0,
54
stack: vec![0_u8; DEFAULT_STACK_SIZE],
55
ctx: ThreadContext::default(),
56
state: State::Running,
57
};
58
59
let mut threads = vec![base_thread];
60
let mut available_threads: Vec<Thread> = (1..MAX_THREADS).map(|i| Thread::new(i)).collect();
61
threads.append(&mut available_threads);
62
63
Runtime {
64
threads,
65
current: 0,
66
}
67
}
68
69
pub fn init(&self) {
70
unsafe {
71
let r_ptr: *const Runtime = self;
72
RUNTIME = r_ptr as usize;
73
}
74
}
75
76
pub fn run(&mut self) -> ! {
77
while self.t_yield() {}
78
std::process::exit(0);
79
}
80
81
fn t_return(&mut self) {
82
if self.current != 0 {
83
self.threads[self.current].state = State::Available;
84
self.t_yield();
85
}
86
}
87
88
#[inline(never)]
89
fn t_yield(&mut self) -> bool {
90
let mut pos = self.current;
91
while self.threads[pos].state != State::Ready {
92
pos += 1;
93
if pos == self.threads.len() {
94
pos = 0;
95
}
96
if pos == self.current {
97
return false;
98
}
99
}
100
101
if self.threads[self.current].state != State::Available {
102
self.threads[self.current].state = State::Ready;
103
}
104
105
self.threads[pos].state = State::Running;
106
let old_pos = self.current;
107
self.current = pos;
108
109
unsafe {
110
let old: *mut ThreadContext = &mut self.threads[old_pos].ctx;
111
let new: *const ThreadContext = &self.threads[pos].ctx;
112
asm!("call switch", in("rdi") old, in("rsi") new, clobber_abi("C"));
113
}
114
self.threads.len() > 0
115
}
116
117
pub fn spawn(&mut self, f: fn()) {
118
let available = self
119
.threads
120
.iter_mut()
121
.find(|t| t.state == State::Available)
122
.expect("no available thread.");
123
124
let size = available.stack.len();
125
unsafe {
126
let s_ptr = available.stack.as_mut_ptr().offset(size as isize);
127
let s_ptr = (s_ptr as usize & !15) as *mut u8;
128
std::ptr::write(s_ptr.offset(-16) as *mut u64, guard as u64);
129
std::ptr::write(s_ptr.offset(-24) as *mut u64, skip as u64);
130
std::ptr::write(s_ptr.offset(-32) as *mut u64, f as u64);
131
available.ctx.rsp = s_ptr.offset(-32) as u64;
132
}
133
available.state = State::Ready;
134
}
135
}
136
137
#[naked]
138
unsafe extern "C" fn skip() {
139
asm!("ret", options(noreturn))
140
}
141
142
fn guard() {
143
unsafe {
144
let rt_ptr = RUNTIME as *mut Runtime;
145
(*rt_ptr).t_return();
146
};
147
}
148
149
pub fn yield_thread() {
150
unsafe {
151
let rt_ptr = RUNTIME as *mut Runtime;
152
(*rt_ptr).t_yield();
153
};
154
}
155
156
#[naked]
157
#[no_mangle]
158
unsafe extern "C" fn switch() {
159
asm!(
160
"mov [rdi + 0x00], rsp",
161
"mov [rdi + 0x08], r15",
162
"mov [rdi + 0x10], r14",
163
"mov [rdi + 0x18], r13",
164
"mov [rdi + 0x20], r12",
165
"mov [rdi + 0x28], rbx",
166
"mov [rdi + 0x30], rbp",
167
"mov rsp, [rsi + 0x00]",
168
"mov r15, [rsi + 0x08]",
169
"mov r14, [rsi + 0x10]",
170
"mov r13, [rsi + 0x18]",
171
"mov r12, [rsi + 0x20]",
172
"mov rbx, [rsi + 0x28]",
173
"mov rbp, [rsi + 0x30]",
174
"ret", options(noreturn)
175
);
176
}
177
pub fn main() {
178
let mut runtime = Runtime::new();
179
runtime.init();
180
runtime.spawn(|| {
181
println!("THREAD 1 STARTING");
182
let id = 1;
183
for i in 0..10 {
184
println!("thread: {} counter: {}", id, i);
185
yield_thread();
186
}
187
println!("THREAD 1 FINISHED");
188
});
189
runtime.spawn(|| {
190
println!("THREAD 2 STARTING");
191
let id = 2;
192
for i in 0..15 {
193
println!("thread: {} counter: {}", id, i);
194
yield_thread();
195
}
196
println!("THREAD 2 FINISHED");
197
});
198
runtime.run();
199
}
Copied!
Copy link