Skip to content

Commit de6b0f4

Browse files
committed
rust: introduce current
This allows Rust code to get a reference to the current task without having to increment the refcount, but still guaranteeing memory safety. Cc: Ingo Molnar <[email protected]> Cc: Peter Zijlstra <[email protected]> Signed-off-by: Wedson Almeida Filho <[email protected]>
1 parent 1f8f11d commit de6b0f4

File tree

3 files changed

+100
-1
lines changed

3 files changed

+100
-1
lines changed

rust/helpers.c

+6
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,12 @@ bool rust_helper_refcount_dec_and_test(refcount_t *r)
100100
}
101101
EXPORT_SYMBOL_GPL(rust_helper_refcount_dec_and_test);
102102

103+
struct task_struct *rust_helper_get_current(void)
104+
{
105+
return current;
106+
}
107+
EXPORT_SYMBOL_GPL(rust_helper_get_current);
108+
103109
void rust_helper_get_task_struct(struct task_struct *t)
104110
{
105111
get_task_struct(t);

rust/kernel/prelude.rs

+2
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,5 @@ pub use super::error::{code::*, Error, Result};
3636
pub use super::{str::CStr, ThisModule};
3737

3838
pub use super::init::{InPlaceInit, Init, PinInit};
39+
40+
pub use super::current;

rust/kernel/task.rs

+92-1
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,60 @@
55
//! C header: [`include/linux/sched.h`](../../../../include/linux/sched.h).
66
77
use crate::bindings;
8-
use core::{cell::UnsafeCell, ptr};
8+
use core::{cell::UnsafeCell, marker::PhantomData, ops::Deref, ptr};
9+
10+
/// Returns the currently running task.
11+
#[macro_export]
12+
macro_rules! current {
13+
() => {
14+
// SAFETY: Deref + addr-of below create a temporary `TaskRef` that cannot outlive the
15+
// caller.
16+
unsafe { &*$crate::task::Task::current() }
17+
};
18+
}
919

1020
/// Wraps the kernel's `struct task_struct`.
1121
///
1222
/// # Invariants
1323
///
1424
/// Instances of this type are always ref-counted, that is, a call to `get_task_struct` ensures
1525
/// that the allocation remains valid at least until the matching call to `put_task_struct`.
26+
///
27+
/// # Examples
28+
///
29+
/// The following is an example of getting the PID of the current thread with zero additional cost
30+
/// when compared to the C version:
31+
///
32+
/// ```
33+
/// let pid = current!().pid();
34+
/// ```
35+
///
36+
/// Getting the PID of the current process, also zero additional cost:
37+
///
38+
/// ```
39+
/// let pid = current!().group_leader().pid();
40+
/// ```
41+
///
42+
/// Getting the current task and storing it in some struct. The reference count is automatically
43+
/// incremented when creating `State` and decremented when it is dropped:
44+
///
45+
/// ```
46+
/// use kernel::{task::Task, types::ARef};
47+
///
48+
/// struct State {
49+
/// creator: ARef<Task>,
50+
/// index: u32,
51+
/// }
52+
///
53+
/// impl State {
54+
/// fn new() -> Self {
55+
/// Self {
56+
/// creator: current!().into(),
57+
/// index: 0,
58+
/// }
59+
/// }
60+
/// }
61+
/// ```
1662
#[repr(transparent)]
1763
pub struct Task(pub(crate) UnsafeCell<bindings::task_struct>);
1864

@@ -25,6 +71,24 @@ unsafe impl Sync for Task {}
2571
type Pid = bindings::pid_t;
2672

2773
impl Task {
74+
/// Returns a task reference for the currently executing task/thread.
75+
///
76+
/// # Safety
77+
///
78+
/// Callers must ensure that the returned [`TaskRef`] doesn't outlive the current task/thread.
79+
pub unsafe fn current<'a>() -> TaskRef<'a> {
80+
// SAFETY: Just an FFI call with no additional safety requirements.
81+
let ptr = unsafe { bindings::get_current() };
82+
83+
TaskRef {
84+
// SAFETY: If the current thread is still running, the current task is valid. Given
85+
// that `TaskRef` is not `Send`, we know it cannot be transferred to another thread
86+
// (where it could potentially outlive the caller).
87+
task: unsafe { &*ptr.cast() },
88+
_not_send: PhantomData,
89+
}
90+
}
91+
2892
/// Returns the group leader of the given task.
2993
pub fn group_leader(&self) -> &Task {
3094
// SAFETY: By the type invariant, we know that `self.0` is valid.
@@ -69,3 +133,30 @@ unsafe impl crate::types::AlwaysRefCounted for Task {
69133
unsafe { bindings::put_task_struct(obj.cast().as_ptr()) }
70134
}
71135
}
136+
137+
/// A wrapper for a shared reference to [`Task`] that isn't [`Send`].
138+
///
139+
/// We make this explicitly not [`Send`] so that we can use it to represent the current thread
140+
/// without having to increment/decrement the task's reference count.
141+
///
142+
/// # Invariants
143+
///
144+
/// The wrapped [`Task`] remains valid for the lifetime of the object.
145+
pub struct TaskRef<'a> {
146+
task: &'a Task,
147+
_not_send: PhantomData<*mut ()>,
148+
}
149+
150+
impl Deref for TaskRef<'_> {
151+
type Target = Task;
152+
153+
fn deref(&self) -> &Self::Target {
154+
self.task
155+
}
156+
}
157+
158+
impl From<TaskRef<'_>> for crate::types::ARef<Task> {
159+
fn from(t: TaskRef<'_>) -> Self {
160+
t.deref().into()
161+
}
162+
}

0 commit comments

Comments
 (0)