-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy paththreads.c
178 lines (137 loc) · 4.49 KB
/
threads.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
#include "interrupt.h"
#include "threads.h"
#include "util.h"
#include "buddy_allocator.h"
#include "assert.h"
#include "lock.h"
/* Capacity of the thread table; thread ids index tables of this size. */
#define MAX_THREADS (1 << 16)
/*
 * Lifecycle of a thread slot, as driven by the functions in this file:
 *   NOT_STARTED - slot not in use (slot 0 initially; a joined slot is reset to
 *                 this before it goes back on the free list).
 *   RUNNING     - thread is schedulable; yield() only switches to RUNNING ids.
 *   FINISHED    - thread called finish_current_thread(); its stack is still
 *                 allocated.
 *   WAIT_JOIN   - check_thread_finished() has freed the stack; ret_val is
 *                 waiting to be collected by thread_join().
 */
typedef enum {NOT_STARTED = 0, RUNNING, FINISHED, WAIT_JOIN} thread_state;
/* Control block for one thread; stored in tp.threads[] at index == thread id. */
struct thread {
/* Saved stack pointer while the thread is switched out: run_thread passes its
 * address to switch_threads for the outgoing thread, and its value as the
 * stack to resume for the incoming thread. */
void* stack_pointer;
/* Base of the stack block obtained from get_page0 (released with free_page). */
void* stack_start;
/* Stack size is PAGE_SIZE * 2^cnt_log_page bytes (create_thread uses 10). */
int cnt_log_page;
/* Result passed to finish_current_thread, handed back by thread_join. */
void * ret_val;
/* Current lifecycle stage, see thread_state above. */
thread_state state;
};
/* Id of the thread currently executing. Id 1 is the context that was running
 * when the library started (init_threads marks it RUNNING) - presumably the
 * program's main thread. */
static volatile pid_t current_thread = 1;
/* Id of the thread that most recently switched away; run_thread sets it just
 * before switch_threads, and check_thread_finished uses it to reclaim that
 * thread's stack if it had finished. */
static volatile pid_t previous_thread = 1;
/*
 * Fixed-size thread table plus two intrusive lists threaded through next[]/prev[]:
 *  - run list: circular doubly-linked list anchored at thread 1 (init_threads
 *    links 1 to itself). get_free_thread inserts new threads right after
 *    thread 1; finish_current_thread unlinks a thread when it exits.
 *  - free list: singly-linked through next[] starting at first_free. Because
 *    tp is zero-initialised and init_threads stops linking at MAX_THREADS-1,
 *    id 0 terminates the free list.
 */
struct threads_pool {
volatile pid_t first_free;
volatile struct thread threads[MAX_THREADS];
volatile pid_t next[MAX_THREADS];
volatile pid_t prev[MAX_THREADS];
};
/* The single global pool (static storage, so every field starts zeroed). */
static struct threads_pool tp;
/*
 * One-time setup of the thread pool.
 *
 * Slots 2..MAX_THREADS-1 are chained into the free list in ascending order.
 * The last slot keeps next == 0 from static zero-initialisation, so id 0 ends
 * the list. The calling context becomes thread 1, alone on the circular run
 * list, and slot 0 is left NOT_STARTED.
 */
void init_threads() {
    tp.first_free = 2;

    int slot = 2;
    while (slot + 1 < MAX_THREADS) {
        tp.next[slot] = slot + 1;
        slot++;
    }

    /* Run list: thread 1 linked to itself in both directions. */
    tp.prev[1] = 1;
    tp.next[1] = 1;

    tp.threads[1].state = RUNNING;
    tp.threads[0].state = NOT_STARTED;
}
/*
 * Pop a slot off the free list and splice it into the run list right after
 * thread 1. create_thread calls this inside its critical section; other
 * callers presumably must do the same (the list updates are not atomic).
 *
 * Returns a pointer to the slot. The caller fills in the stack and sets the
 * state.
 *
 * Id 0 terminates the free list, so reaching it means every slot is in use.
 * Handing out slot 0 would corrupt the run list: next[0] is 0, so later calls
 * would return slot 0 again. The thread would also never run, because yield()
 * skips id 0. Fail loudly instead.
 */
volatile struct thread *get_free_thread() {
    pid_t first = tp.first_free;
    assert(first != 0); /* thread pool exhausted */
    tp.first_free = tp.next[first];
    /* Insert `first` between thread 1 and thread 1's current successor. */
    tp.next[first] = tp.next[1];
    tp.prev[first] = 1;
    tp.prev[tp.next[1]] = first;
    tp.next[1] = first;
    return &tp.threads[first];
}
pid_t create_thread(void* (*fptr)(void *), void *arg) {
start_critical_section();
volatile struct thread *new_thread = get_free_thread();
new_thread->cnt_log_page = 10;
new_thread->stack_start = get_page0(new_thread->cnt_log_page);
new_thread->stack_pointer = (uint8_t *)new_thread->stack_start + PAGE_SIZE * (1 << (new_thread->cnt_log_page));
struct init_thread_data {
uint64_t r15, r14, r13, r12, rbx, rbp;
void* start_thread_addr;
void* fun_addr;
void* arg;
};
new_thread->stack_pointer = (uint8_t *)new_thread->stack_pointer - sizeof(struct init_thread_data);
struct init_thread_data* init_val = new_thread->stack_pointer;
init_val->r12 = 0;
init_val->r13 = 0;
init_val->r14 = 0;
init_val->r15 = 0;
init_val->rbx = 0;
init_val->rbp = 0;
extern void *start_thread;
init_val->start_thread_addr = &start_thread;
init_val->fun_addr = fptr;
init_val->arg = arg;
new_thread->state = RUNNING;
end_critical_section();
return (pid_t)(new_thread - tp.threads);
}
/* Context-switch routine defined elsewhere (presumably assembly). run_thread
 * passes the outgoing thread's stack_pointer slot as old_sp and the incoming
 * thread's saved stack pointer as new_sp. */
void switch_threads(void **old_sp, void *new_sp);

/*
 * Reclaim the stack of the thread that just switched away, if it had finished.
 *
 * run_thread calls this right after switch_threads returns. The check
 * previous_thread != current_thread ensures we are no longer executing on the
 * finished thread's stack before it is freed. The slot then moves to
 * WAIT_JOIN so thread_join can collect it.
 */
void check_thread_finished() {
    volatile struct thread *left = &tp.threads[previous_thread];
    if (left->state != FINISHED || previous_thread == current_thread) {
        return;
    }
    free_page(left->stack_start, left->cnt_log_page);
    left->state = WAIT_JOIN;
}
/*
 * Switch execution from the current thread to thread `tid`.
 *
 * If `tid` is already current, this does nothing. Otherwise it:
 *   - records the outgoing thread in previous_thread and makes `tid` current;
 *   - saves the outgoing stack pointer into the outgoing thread's control block;
 *   - resumes `tid` on its saved stack.
 *
 * Control comes back here only when a later switch resumes the outgoing
 * thread. At that point check_thread_finished() reclaims the stack of
 * whichever thread just left, if it had finished.
 *
 * yield() calls this inside its critical section.
 */
void run_thread(pid_t tid) {
    if (current_thread == tid) {
        return;
    }
    /* Keep the volatile qualifier: tp.threads is defined volatile, and
     * accessing it through a non-volatile lvalue is undefined behaviour. */
    volatile struct thread *incoming = &tp.threads[tid];
    pid_t outgoing_id = current_thread;
    current_thread = tid;
    previous_thread = outgoing_id;
    volatile struct thread *outgoing = &tp.threads[outgoing_id];
    /* switch_threads takes a plain void **, so the qualifier is dropped only
     * for this hand-off to the external switch routine. */
    switch_threads((void **)&outgoing->stack_pointer, incoming->stack_pointer);
    check_thread_finished();
}
void finish_current_thread(void* val) {
start_critical_section();
int ct = get_current_thread();
//printf("thread finish %d\n", ct);
volatile struct thread* current_t = tp.threads + ct;
current_t->state = FINISHED;
current_t->ret_val = val;
//printf("%d %d %d\n", ct, tp.next[ct], tp.prev[ct]);
tp.prev[tp.next[ct]] = tp.prev[ct];
tp.next[tp.prev[ct]] = tp.next[ct];
end_critical_section();
yield();
assert(0);
}
/*
 * Give up the CPU to the next runnable thread in the run list.
 *
 * Walks the circular run list, starting after the current thread, and
 * switches to the first RUNNING thread found. Id 0 is never scheduled.
 * If that thread is the caller itself, run_thread() does nothing and yield
 * simply returns.
 *
 * If a full lap of the list finds nothing runnable, yield returns without
 * switching instead of spinning with the critical section held. This happens
 * when the only remaining thread has just finished.
 */
void yield() {
    start_critical_section();
    pid_t start = tp.next[current_thread];
    pid_t i = start;
    do {
        if (i != 0 && tp.threads[i].state == RUNNING) {
            run_thread(i);
            break;
        }
        /* Advance from the candidate, not from current_thread: re-reading
         * next[current_thread] would revisit the same candidate forever. */
        i = tp.next[i];
    } while (i != start);
    end_critical_section();
}
/*
 * Wait for thread `thread` to finish, then release its slot.
 *
 * Yields until the target reaches WAIT_JOIN, meaning it has finished and its
 * stack has been freed. If `retval` is non-null, the target's result is
 * stored there. The slot is then reset to NOT_STARTED and pushed back onto
 * the free list.
 *
 * Each finished thread should be joined exactly once. An id that never
 * reaches WAIT_JOIN (never created, or already joined) makes this spin forever.
 */
void thread_join(pid_t thread, void** retval) {
    volatile struct thread *target = &tp.threads[thread];
    for (;;) {
        if (target->state == WAIT_JOIN) {
            break;
        }
        yield();
        barrier(); /* from util.h - presumably a compiler/memory barrier */
    }

    if (retval)
        *retval = target->ret_val;

    target->state = NOT_STARTED;

    /* Push the slot onto the head of the free list. */
    start_critical_section();
    tp.next[thread] = tp.first_free;
    tp.first_free = thread;
    end_critical_section();
}
/* Return the id of the thread that is currently executing. */
pid_t get_current_thread() {
    pid_t running = current_thread;
    return running;
}