
Commit ba15477

Setup pre-commit hooks. (#1)

1 parent daa769e

14 files changed (+73, -33 lines)

.github/workflows/pre-commit.yaml (+14)

@@ -0,0 +1,14 @@
+name: pre-commit
+
+on:
+  pull_request:
+  push:
+    branches: [main]
+
+jobs:
+  pre-commit:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - uses: actions/setup-python@v3
+      - uses: pre-commit/[email protected]

.pre-commit-config.yaml (+14)

@@ -0,0 +1,14 @@
+# See https://pre-commit.com for more information
+# See https://pre-commit.com/hooks.html for more hooks
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v3.2.0
+    hooks:
+      - id: trailing-whitespace
+      - id: end-of-file-fixer
+      - id: check-yaml
+      - id: check-added-large-files
+  - repo: https://github.com/psf/black
+    rev: 22.10.0
+    hooks:
+      - id: black
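
For a sense of what these hooks enforce: the first four are generic hygiene checks (trailing whitespace, final newlines, YAML syntax, large files), and `black` reformats Python code. Below is a small, hypothetical before/after snippet (not from this repository) showing the kind of rewrite the `black` hook performs; exact output depends on the pinned rev (22.10.0 above).

```python
# Hypothetical input, written the way a developer might dash it off:
def make_entry(name,value): return {'name':name,'value':value,'tags':[  ]}

# Roughly what `black` rewrites it to: quotes normalized to double quotes,
# spacing fixed, and the one-line def expanded onto its own body line.
def make_entry(name, value):
    return {"name": name, "value": value, "tags": []}
```

Once installed locally (see the README change further down), the same hooks can also be run across the whole tree with `pre-commit run --all-files`.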

LICENSE (-1)

@@ -216,4 +216,3 @@ conflicts with the conditions of the GPLv2, you may retroactively and
 prospectively choose to deem waived or otherwise exclude such Section(s) of
 the License, but only in their entirety and only with respect to the Combined
 Software.
-

README.md (+10, -1)

@@ -3,9 +3,12 @@
 **WARNING: This is an early preview that is in progress. It is not ready for
 general use.**
 
+[![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit)](https://github.com/pre-commit/pre-commit)
+
+
 ## Development Getting Started
 
-Use this as a guide to get started developing the project using pinned,
+Use this as a guide to get started developing the project using pinned,
 pre-release dependencies. You are welcome to deviate as you see fit, but
 these canonical directions mirror what the CI does.
 
@@ -54,3 +57,9 @@ pip install -e shortfin
 pytest sharktank
 pytest shortfin
 ```
+
+### Optional: Pre-commits and developer settings
+
+This project is set up to use the `pre-commit` tooling. To install it in
+your local repo, run: `pre-commit install`. After this point, when making
+commits locally, hooks will run. See https://pre-commit.com/

requirements.txt (+5, -1)

@@ -7,4 +7,8 @@ types-requests==2.31.0.20240125
 
 # It is expected that you have installed a PyTorch version/variant specific
 # to your needs, so we only include a minimum version spec.
-torch>=2.3.0
+# TODO: Use a versioned release once 2.3.0 drops.
+torch>=2.3.0.dev1
+
+# Used for managing pre-commit flows.
+pre-commit

sharktank/sharktank/layers/kv_cache.py (+1, -1)

@@ -8,7 +8,7 @@
 
 These are not complete abstractions: they are primarily focused on making
 tightly coupled transformer blocks a bit less "stringy" with loose tensors
-and dims floating around everywhere.
+and dims floating around everywhere.
 """
 
 import abc

sharktank/sharktank/ops/templates/mmt_block_scaled_offset_q4_unsigned.mlir (+5, -5)

@@ -39,8 +39,8 @@ util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{n}_{k}_{bs}
       affine_map<(d0, d1, d2) -> (d0, d1, 0)>,
       affine_map<(d0, d1, d2) -> (d0, d1, 0)>,
       affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
-      affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
-      iterator_types = ["parallel", "parallel", "parallel"] }}
+      affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+      iterator_types = ["parallel", "parallel", "parallel"] }}
       ins(%d, %m, %qs : !d_tensor_type, !m_tensor_type, !qs_tensor_type)
       outs(%b_grouped : !b_grouped_tensor_type) {{
   ^bb0(%d_element: !scale_type, %m_element: !scale_type, %q_element: !lowp_type, %out: !a_type):
@@ -63,9 +63,9 @@ util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{n}_{k}_{bs}
       indexing_maps = [
       // d0 = b, d1 = m, d2 = n, d3 = group0 (r), d4 = block (r)
       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>,
-      affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>,
-      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>],
-      iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"] }}
+      affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>,
+      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>],
+      iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"] }}
       ins(%aexp, %b_grouped_dequant : !aexp_tensor_type, !b_grouped_tensor_type)
       outs(%result_fill : !c_tensor_type) {{
   ^bb0(%a_element: !a_type, %b_element: !a_type, %out: !a_type):

sharktank/sharktank/ops/templates/mmt_block_scaled_q8_3d.mlir (+6, -6)

@@ -32,9 +32,9 @@ util.func private @sharktank_mmt_block_scaled_q8_3d_{n}_{k}_{bs}_{a_type}(
   %b_grouped_dequant = linalg.generic {{
       indexing_maps = [
       affine_map<(d0, d1, d2) -> (d0, d1, 0)>,
-      affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
-      affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
-      iterator_types = ["parallel", "parallel", "parallel"] }}
+      affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+      affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+      iterator_types = ["parallel", "parallel", "parallel"] }}
       ins(%d, %qs : !d_tensor_type, !qs_tensor_type)
       outs(%b_grouped : !b_grouped_tensor_type) {{
   ^bb0(%d_element: !scale_type, %q_element: !lowp_type, %out: !a_type):
@@ -55,9 +55,9 @@ util.func private @sharktank_mmt_block_scaled_q8_3d_{n}_{k}_{bs}_{a_type}(
      indexing_maps = [
      // d0 = b, d1 = m, d2 = n, d3 = group0 (r), d4 = block (r)
      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>,
-     affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>,
-     affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>],
-     iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"] }}
+     affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>],
+     iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"] }}
      ins(%aexp, %b_grouped_dequant : !aexp_tensor_type, !b_grouped_tensor_type)
      outs(%result_fill : !c_tensor_type) {{
   ^bb0(%a_element: !a_type, %b_element: !a_type, %out: !a_type):

sharktank/sharktank/ops/templates/mmt_super_block_scaled_offset_q4_unsigned_3d.mlir (+7, -7)

@@ -25,8 +25,8 @@
 module {{
 
 util.func private @mmt_super_block_scaled_offset_q4_unsigned_3d_{n}_{k}_{sup_count}_{sub_count}_{bs}_{a_type}(
-    %a: !a_tensor_type,
-    %d: !d_tensor_type,
+    %a: !a_tensor_type,
+    %d: !d_tensor_type,
     %dmin: !dmin_tensor_type,
     %sb_scales_hi_i8: !sb_hi_i8_type,
     %sb_scales_low_i8: !sb_low_i8_type,
@@ -59,11 +59,11 @@ util.func private @mmt_super_block_scaled_offset_q4_unsigned_3d_{n}_{k}_{sup_count}_{sub_count}_{bs}_{a_type}(
       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, // sb_mins_hi[n, sup, sub]
       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, // sb_mins_low[n, sup, sub]
       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> // out b_grouped[n, sup, sub, bs]
-      ],
+      ],
       iterator_types = ["parallel", "parallel", "parallel", "parallel"] }}
       ins(
       %qs, %d, %dmin, %sb_scales_hi, %sb_scales_low, %sb_mins_hi, %sb_mins_low :
-      !qs_tensor_type, !d_tensor_type, !dmin_tensor_type,
+      !qs_tensor_type, !d_tensor_type, !dmin_tensor_type,
       !sb_hi_i2_type, !sb_low_i4_type, !sb_hi_i2_type, !sb_low_i4_type
       )
       outs(%b_grouped : !b_grouped_tensor_type) {{
@@ -74,7 +74,7 @@ util.func private @mmt_super_block_scaled_offset_q4_unsigned_3d_{n}_{k}_{sup_count}_{sub_count}_{bs}_{a_type}(
     %shift_4 = arith.constant 4 : i32
     %d_element_ext = arith.extf %d_element : !scale_type to !a_type
     %dmin_element_ext = arith.extf %dmin_element : !scale_type to !a_type
-
+
     // Combine sub-block scale.
     %sb_scale_low_i32 = arith.extui %sb_scales_low_element : i4 to i32
     %sb_scale_hi_i32 = arith.extui %sb_scales_hi_element : i2 to i32
@@ -111,8 +111,8 @@ util.func private @mmt_super_block_scaled_offset_q4_unsigned_3d_{n}_{k}_{sup_count}_{sub_count}_{bs}_{a_type}(
       affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4, d5)>, // aexp
       affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>, // b_grouped_dequant
      affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)> // out
-      ],
-      iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "reduction"] }}
+      ],
+      iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "reduction"] }}
       ins(%aexp, %b_grouped_dequant : !aexp_tensor_type, !b_grouped_tensor_type)
       outs(%result_fill : !c_tensor_type) {{
   ^bb0(%a_element: !a_type, %b_element: !a_type, %out: !a_type):

sharktank/sharktank/ops/templates/mmtfp_2d.mlir (+2, -2)

@@ -20,8 +20,8 @@ util.func private @sharktank_mmtfp_2d_{n}_{k}_{a_type}{bT_type}(
   %c1 = arith.constant 1 : index
   %m = tensor.dim %a, %c0 : !a_tensor_type
   %result_empty = tensor.empty(%m) : !c_tensor_type
-  %result_init = linalg.fill
-      ins(%zero : {a_type})
+  %result_init = linalg.fill
+      ins(%zero : {a_type})
       outs(%result_empty: !c_tensor_type) -> !c_tensor_type
   %result = linalg.matmul_transpose_b
       ins (%a, %bT: !a_tensor_type, !bT_tensor_type)

sharktank/sharktank/ops/templates/mmtfp_3d.mlir (+2, -2)

@@ -34,8 +34,8 @@ util.func private @sharktank_mmtfp_3d_{n}_{k}_{a_type}{bT_type}(
     linalg.yield %in : !bT_type
   }} -> !bT_broadcast_tensor_type
   %result_empty = tensor.empty(%b0, %m) : !c_tensor_type
-  %result_init = linalg.fill
-      ins(%zero : !a_type)
+  %result_init = linalg.fill
+      ins(%zero : !a_type)
       outs(%result_empty: !c_tensor_type) -> !c_tensor_type
   %result = linalg.batch_matmul_transpose_b
       ins (%a, %bT_broadcast: !a_tensor_type, !bT_broadcast_tensor_type)

shortfin/shortfin/framework/session.py (+1, -1)

@@ -10,7 +10,7 @@
 
 * DeviceSession: A single HAL device and other process-level globals. Shared global
   memory and corresponding synchronization handles are accessible from here.
-* WorkQueue: Logical stream of execution, nested under the DeviceSession. Each
+* WorkQueue: Logical stream of execution, nested under the DeviceSession. Each
   queue holds a timeline semaphore which sequences invocations. For these models,
   we route workloads of vastly different characteristics to distinct queues (i.e.
   prefill vs decode step).

shortfin/shortfin/llm/config.py (+5, -5)

@@ -7,10 +7,10 @@
 """Configuration objects.
 
 Parameters that are intrinsic to a specific model.
-
+
 In a typical transformer model, the KV cache is organized similar to (mapped to
 our parameter names below):
-  k = tensor.empty(transformer_block_count, batch_size, seq,
+  k = tensor.empty(transformer_block_count, batch_size, seq,
                    attn_head_count, attn_head_dim)
   v = ...
 
@@ -28,9 +28,9 @@
 
 In this scenario, we declare that one block holds the KV cache for all transformer
 block layers because it reduces the accounting. As such, for the above example,
-a single position in the sequence will be 524,288 bytes, assuming a 2-byte element
-type. If we choose to block by block_stride=16 positions, each block will be 8MiB.
-Assuming we wanted to dedicate 12GiB to the block cache, this would equate to 1536
+a single position in the sequence will be 524,288 bytes, assuming a 2-byte element
+type. If we choose to block by block_stride=16 positions, each block will be 8MiB.
+Assuming we wanted to dedicate 12GiB to the block cache, this would equate to 1536
 blocks for a total number of sequence positions of 24,576.
 
 These are well-known numbers but are derived above to give a sense of scale.
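
The sizing arithmetic quoted in that docstring is easy to recompute. The sketch below is not part of the commit; it simply reproduces the 524,288-byte, 8MiB, 1536-block, and 24,576-position figures. The model parameters used (32 transformer blocks, 32 attention heads, head dim 128) are illustrative assumptions chosen to be consistent with those figures, not values taken from this diff.

```python
# Sketch of the KV cache sizing math from the config.py docstring above.
# Model parameters are illustrative assumptions, not values from this commit.
transformer_block_count = 32  # assumed
attn_head_count = 32          # assumed
attn_head_dim = 128           # assumed
element_bytes = 2             # "assuming a 2-byte element type"
block_stride = 16             # positions per cache block
cache_budget_bytes = 12 * 1024**3  # 12GiB dedicated to the block cache

# One sequence position stores K and V for every transformer block layer.
bytes_per_position = (
    2 * transformer_block_count * attn_head_count * attn_head_dim * element_bytes
)
bytes_per_block = bytes_per_position * block_stride
total_blocks = cache_budget_bytes // bytes_per_block
total_positions = total_blocks * block_stride

print(bytes_per_position)          # 524288
print(bytes_per_block // 1024**2)  # 8 (MiB)
print(total_blocks)                # 1536
print(total_positions)             # 24576
```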

version_info.json (+1, -1)

@@ -1 +1 @@
-{"package-version": "0.1.dev2"}
+{"package-version": "0.1.dev3"}
