diff --git a/kernel/sched/ext.c b/kernel/sched/ext.c index 93e041e2f8d75..55a9c4fd208db 100644 --- a/kernel/sched/ext.c +++ b/kernel/sched/ext.c @@ -5183,11 +5183,11 @@ extern struct btf *btf_vmlinux; static const struct btf_type *task_struct_type; static u32 task_struct_type_id; -/* Make the 2nd argument of .dispatch a pointer that can be NULL. */ -static bool promote_dispatch_2nd_arg(int off, int size, - enum bpf_access_type type, - const struct bpf_prog *prog, - struct bpf_insn_access_aux *info) +static bool promote_op_nth_arg(int arg_n, const char *op, + int off, int size, + enum bpf_access_type type, + const struct bpf_prog *prog, + struct bpf_insn_access_aux *info) { struct btf *btf = bpf_get_btf_vmlinux(); const struct bpf_struct_ops_desc *st_ops_desc; @@ -5196,6 +5196,10 @@ static bool promote_dispatch_2nd_arg(int off, int size, u32 btf_id, member_idx; const char *mname; + /* struct_ops op args are all sequential, 64-bit numbers */ + if (off != arg_n * sizeof(__u64)) + return false; + /* btf_id should be the type id of struct sched_ext_ops */ btf_id = prog->aux->attach_btf_id; st_ops_desc = bpf_struct_ops_find(btf, btf_id); @@ -5217,14 +5221,7 @@ static bool promote_dispatch_2nd_arg(int off, int size, member = &btf_type_member(t)[member_idx]; mname = btf_name_by_offset(btf_vmlinux, member->name_off); - /* - * Check if it is the second argument of the function pointer at - * "dispatch" in struct sched_ext_ops. The arguments of struct_ops - * operators are sequential and 64-bit, so the second argument is at - * offset sizeof(__u64). - */ - if (strcmp(mname, "dispatch") == 0 && - off == sizeof(__u64)) { + if (!strcmp(mname, op)) { /* * The value is a pointer to a type (struct task_struct) given * by a BTF ID (PTR_TO_BTF_ID). It is trusted (PTR_TRUSTED), @@ -5245,6 +5242,15 @@ static bool promote_dispatch_2nd_arg(int off, int size, return false; } +static bool promote_op_arg(int off, int size, + enum bpf_access_type type, + const struct bpf_prog *prog, + struct bpf_insn_access_aux *info) +{ + return promote_op_nth_arg(1, "dispatch", off, size, type, prog, info) || + promote_op_nth_arg(1, "yield", off, size, type, prog, info); +} + static bool bpf_scx_is_valid_access(int off, int size, enum bpf_access_type type, const struct bpf_prog *prog, @@ -5252,7 +5258,7 @@ static bool bpf_scx_is_valid_access(int off, int size, { if (type != BPF_READ) return false; - if (promote_dispatch_2nd_arg(off, size, type, prog, info)) + if (promote_op_arg(off, size, type, prog, info)) return true; if (off < 0 || off >= sizeof(__u64) * MAX_BPF_FUNC_ARGS) return false; diff --git a/tools/testing/selftests/sched_ext/maybe_null.bpf.c b/tools/testing/selftests/sched_ext/maybe_null.bpf.c index ad5e694226bbf..27d0f386acfb1 100644 --- a/tools/testing/selftests/sched_ext/maybe_null.bpf.c +++ b/tools/testing/selftests/sched_ext/maybe_null.bpf.c @@ -18,9 +18,19 @@ void BPF_STRUCT_OPS(maybe_null_success_dispatch, s32 cpu, struct task_struct *p) vtime_test = p->scx.dsq_vtime; } +bool BPF_STRUCT_OPS(maybe_null_success_yield, struct task_struct *from, + struct task_struct *to) +{ + if (to) + bpf_printk("Yielding to %s[%d]", to->comm, to->pid); + + return false; +} + SEC(".struct_ops.link") struct sched_ext_ops maybe_null_success = { .dispatch = maybe_null_success_dispatch, + .yield = maybe_null_success_yield, .enable = maybe_null_running, .name = "minimal", }; diff --git a/tools/testing/selftests/sched_ext/maybe_null.c b/tools/testing/selftests/sched_ext/maybe_null.c index 3f26b784f9c57..31cfafb0cf65a 100644 --- a/tools/testing/selftests/sched_ext/maybe_null.c +++ b/tools/testing/selftests/sched_ext/maybe_null.c @@ -7,13 +7,15 @@ #include #include #include "maybe_null.bpf.skel.h" -#include "maybe_null_fail.bpf.skel.h" +#include "maybe_null_fail_dsp.bpf.skel.h" +#include "maybe_null_fail_yld.bpf.skel.h" #include "scx_test.h" static enum scx_test_status run(void *ctx) { struct maybe_null *skel; - struct maybe_null_fail *fail_skel; + struct maybe_null_fail_dsp *fail_dsp; + struct maybe_null_fail_yld *fail_yld; skel = maybe_null__open_and_load(); if (!skel) { @@ -22,10 +24,17 @@ static enum scx_test_status run(void *ctx) } maybe_null__destroy(skel); - fail_skel = maybe_null_fail__open_and_load(); - if (fail_skel) { - maybe_null_fail__destroy(fail_skel); - SCX_ERR("Should failed to open and load maybe_null_fail skel"); + fail_dsp = maybe_null_fail_dsp__open_and_load(); + if (fail_dsp) { + maybe_null_fail_dsp__destroy(fail_dsp); + SCX_ERR("Should failed to open and load maybe_null_fail_dsp skel"); + return SCX_TEST_FAIL; + } + + fail_yld = maybe_null_fail_yld__open_and_load(); + if (fail_yld) { + maybe_null_fail_yld__destroy(fail_yld); + SCX_ERR("Should failed to open and load maybe_null_fail_yld skel"); return SCX_TEST_FAIL; } diff --git a/tools/testing/selftests/sched_ext/maybe_null_fail.bpf.c b/tools/testing/selftests/sched_ext/maybe_null_fail_dsp.bpf.c similarity index 93% rename from tools/testing/selftests/sched_ext/maybe_null_fail.bpf.c rename to tools/testing/selftests/sched_ext/maybe_null_fail_dsp.bpf.c index 1607fe07bead1..c0641050271d3 100644 --- a/tools/testing/selftests/sched_ext/maybe_null_fail.bpf.c +++ b/tools/testing/selftests/sched_ext/maybe_null_fail_dsp.bpf.c @@ -21,5 +21,5 @@ SEC(".struct_ops.link") struct sched_ext_ops maybe_null_fail = { .dispatch = maybe_null_fail_dispatch, .enable = maybe_null_running, - .name = "minimal", + .name = "maybe_null_fail_dispatch", }; diff --git a/tools/testing/selftests/sched_ext/maybe_null_fail_yld.bpf.c b/tools/testing/selftests/sched_ext/maybe_null_fail_yld.bpf.c new file mode 100644 index 0000000000000..3c1740028e3b9 --- /dev/null +++ b/tools/testing/selftests/sched_ext/maybe_null_fail_yld.bpf.c @@ -0,0 +1,28 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * Copyright (c) 2024 Meta Platforms, Inc. and affiliates. + */ + +#include + +char _license[] SEC("license") = "GPL"; + +u64 vtime_test; + +void BPF_STRUCT_OPS(maybe_null_running, struct task_struct *p) +{} + +bool BPF_STRUCT_OPS(maybe_null_fail_yield, struct task_struct *from, + struct task_struct *to) +{ + bpf_printk("Yielding to %s[%d]", to->comm, to->pid); + + return false; +} + +SEC(".struct_ops.link") +struct sched_ext_ops maybe_null_fail = { + .yield = maybe_null_fail_yield, + .enable = maybe_null_running, + .name = "maybe_null_fail_yield", +};