Skip to content

Commit

Permalink
Merge branch 'master' into xc/dnn3.6_based_on_master
Browse files Browse the repository at this point in the history
  • Loading branch information
xczhai authored Dec 26, 2024
2 parents ce58e91 + 1ec91fc commit c868510
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,19 @@ ov::pass::PositionIDsReplacerQwen::PositionIDsReplacerQwen(const Output<Node>& p
auto p_opt_convert = optional<v0::Convert>(p_max_context_len);
auto p_opt_reshape = optional<v1::Reshape>({p_opt_convert, any_input()});

// current seg len
auto p_input_ids = wrap_type<v0::Parameter>();
auto p_unsqueeze = wrap_type<v0::Unsqueeze>({p_input_ids, _const()});
auto p_shape_of = wrap_type<v3::ShapeOf>({p_unsqueeze});
// current seq len:
// it might be present in 2 different ways:
// input_ids -> unsqueeze -> reshape -> convert -> shape_of -> gather
// QKV -> variadic_split(Q or K) -> rope Q/K -> shape_of -> gather
// Probably we can use the symbols to re-use one of these ways.
// Currently, "any_input" is used to detect the both places.
auto p_shape_of = wrap_type<v3::ShapeOf>({any_input()});
auto p_current_len = wrap_type<v8::Gather>({p_shape_of, _const(), _const()});

auto p_rotary_emb_sincos = wrap_type<v0::Constant>();
auto p_neg_const = wrap_type<v0::Constant>();
auto p_neg_mul = wrap_type<v1::Multiply>({p_current_len, p_neg_const});
// the rotary_emb_cos/rotary_emb_sin are sliced by the total length [1,..4096,1,128]
auto p_rotary_emb_sincos = wrap_type<v0::Constant>();
auto p_slice_1 = wrap_type<v8::Slice>({p_rotary_emb_sincos, _const(), p_opt_reshape, _const(), _const()});
auto p_slice_2 = wrap_type<v8::Slice>({p_slice_1, p_neg_mul, _const(), _const(), _const()});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,18 @@
} \
}

#define FUNC_LOAD_LEFTOVERS(inner, outer) unroll_for (uint lh = 0; lh < outer; ++lh) { \
const uint input_idx = INPUT0_GET_TILED_INDEX(INPUT0_TILED_ORDER); \
INPUTVTYPE read_data; \
unroll_for (uint lw = 0; lw < inner; ++lw) { \
read_data[lw] = input[input_idx + lw]; \
} \
unroll_for (uint lw = 0; lw < inner; ++lw) { \
const uint dst = local_buf_offset + lw; \
transpose_buf[dst][lh] = read_data[lw]; \
} \
}

#define FUNC_VSTORE(loop) unroll_for (uint lw = 0; lw < loop; ++lw) { \
const uint output_idx = output_idx_tile + (lw * x_pitch); \
VSTORE(TO_OUTPUTVTYPE(transpose_buf[local_buf_offset + lw]), 0, output + output_idx); \
Expand Down Expand Up @@ -109,7 +121,15 @@ KERNEL (reorder_data_bfyx_to_blocked_format)(

if (F_NO_REMAINDER_CONDITION) {
// read and transpose
#ifdef X_REMAINDER_CONDITION
if (X_NO_REMAINDER_CONDITION) {
FUNC_VLOAD(TILE_SIZE, TILE_SIZE)
} else {
FUNC_LOAD_LEFTOVERS(X_REMAINDER_SIZE, TILE_SIZE)
}
#else
FUNC_VLOAD(TILE_SIZE, TILE_SIZE)
#endif

// write to ddr
#ifdef X_REMAINDER_CONDITION
Expand All @@ -125,7 +145,15 @@ KERNEL (reorder_data_bfyx_to_blocked_format)(
#ifdef F_REMAINDER_CONDITION
else if (F_REMAINDER_CONDITION) {
// read and transpose
#ifdef X_REMAINDER_CONDITION
if (X_NO_REMAINDER_CONDITION) {
FUNC_VLOAD(TILE_SIZE, F_REMAINDER_SIZE)
} else {
FUNC_LOAD_LEFTOVERS(X_REMAINDER_SIZE, F_REMAINDER_SIZE)
}
#else
FUNC_VLOAD(TILE_SIZE, F_REMAINDER_SIZE)
#endif

// write to ddr
#ifdef X_REMAINDER_CONDITION
Expand Down
3 changes: 0 additions & 3 deletions tests/constraints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,8 @@ pytest>=5.0,<8.4
pytest-dependency==0.5.1
pytest-html==4.1.1
pytest-timeout==2.3.1
jax<=0.4.36
jaxlib<=0.4.36
kornia==0.7.0
networkx<=3.3
flax<=0.10.2

--extra-index-url https://download.pytorch.org/whl/cpu
torch~=2.5.1; platform_system != "Darwin" or platform_machine != "x86_64"
Expand Down
2 changes: 0 additions & 2 deletions tests/layer_tests/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,3 @@ pytest
defusedxml
tensorflow
tensorflow-addons; python_version <= '3.10'
jax; sys_platform == "linux" and platform_machine == "x86_64" # https://jax.readthedocs.io/en/latest/installation.html#pip-installation-cpu - wheels are for "x86_64" only
jaxlib; sys_platform == "linux" and platform_machine == "x86_64" # https://jax.readthedocs.io/en/latest/installation.html#pip-installation-cpu - wheels are for "x86_64" only

0 comments on commit c868510

Please sign in to comment.