-
Notifications
You must be signed in to change notification settings - Fork 649
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Applying the jetson fixes #847
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2409,7 +2409,7 @@ template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kd | |
} | ||
|
||
|
||
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols) | ||
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, int8_t *out_col_normed, int8_t *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure why this is needed, but as long as it compiles on all platforms (looking at you, MSVC :) ), I don't see a problem with the change either .IIRC, int8_t is exactly 8 bits, while char is at least 8 bits |
||
{ | ||
// assumes TILE_SIZE == THREADS*ITEMS_PER_THREAD | ||
// Each thread reads the same column but multiple rows | ||
|
@@ -2431,15 +2431,15 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int S | |
|
||
typedef cub::BlockLoad<half, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_VECTORIZE> LoadHalf; | ||
__shared__ typename LoadHalf::TempStorage loadhalf; | ||
typedef cub::BlockStore<char, THREADS, ITEMS_PER_THREAD, cub::BLOCK_STORE_VECTORIZE> StoreInt8; | ||
typedef cub::BlockStore<int8_t, THREADS, ITEMS_PER_THREAD, cub::BLOCK_STORE_VECTORIZE> StoreInt8; | ||
__shared__ typename StoreInt8::TempStorage storeint8; | ||
|
||
__shared__ float smem_row_stats[TILE_ROWS]; | ||
__shared__ unsigned int smem_nnz_row_idx[TILE_ROWS]; | ||
|
||
half local_data[ITEMS_PER_THREAD]; | ||
float local_col_stats[ITEMS_PER_THREAD]; | ||
char local_quantized_data[ITEMS_PER_THREAD]; | ||
int8_t local_quantized_data[ITEMS_PER_THREAD]; | ||
|
||
// 0. Load row stats data into shared memory; load col stat (1 fixed per thread) | ||
#pragma unroll ITEMS_PER_THREAD | ||
|
@@ -2489,11 +2489,11 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int S | |
} | ||
else | ||
{ | ||
local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*row_stat)); | ||
local_quantized_data[j] = (int8_t)(rintf(__half2float(local_data[j])*row_stat)); | ||
} | ||
} | ||
else | ||
local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*row_stat)); | ||
local_quantized_data[j] = (int8_t)(rintf(__half2float(local_data[j])*row_stat)); | ||
} | ||
|
||
StoreInt8(storeint8).Store(&(out_row_normed[i]), local_quantized_data, valid_items); | ||
|
@@ -2504,7 +2504,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int S | |
{ | ||
// we already pre-normalized the col/row stat: | ||
// what this does is float/absmax*127 = int8 | ||
local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*local_col_stats[j])); | ||
local_quantized_data[j] = (int8_t)(rintf(__half2float(local_data[j])*local_col_stats[j])); | ||
} | ||
|
||
__syncthreads(); | ||
|
@@ -3832,8 +3832,8 @@ template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>( | |
|
||
template __global__ void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n); | ||
|
||
template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); | ||
template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); | ||
template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, int8_t *out_col_normed, int8_t *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); | ||
template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, int8_t *out_col_normed, int8_t *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); | ||
|
||
template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x); | ||
template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,6 +28,9 @@ FORCE_INLINE int popcnt32(int x32) | |
|
||
#if defined(USE_AVX) || defined(USE_AVX2) | ||
#include <immintrin.h> | ||
#elif defined __aarch64__ | ||
#warning "--- THIS IS AARCH64" | ||
#include <sse2neon.h> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We are going to need to support Neon one way or the other. I am pondering if this is the right approach though, or if we should implement the Neon intrinsics directly? If it saves us time in the short run, maybe a viable option? |
||
#else | ||
#include <emmintrin.h> | ||
#ifdef USE_SSE41 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we confirm that the cmake file works with the Jetson devices? It compiles, but I do not have a device to test with.
Wheels can be taken from the latest build from here
https://github.com/TimDettmers/bitsandbytes/actions/workflows/python-package.yml