Skip to content

Commit

Permalink
Merge branch 'k230' of github.com:ucb-bar/Baremetal-NN into main
Browse files Browse the repository at this point in the history
  • Loading branch information
T-K-233 committed Sep 14, 2024
2 parents 9144b10 + 2adc3e5 commit a784006
Show file tree
Hide file tree
Showing 11 changed files with 1,337 additions and 109 deletions.
12 changes: 9 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
########################################################################################################################
cmake_minimum_required(VERSION 3.10)

project(nn LANGUAGES C)
project(nn LANGUAGES C ASM)

# Options
option(BUILD_SHARED_LIBS "Build using shared libraries" OFF )
Expand All @@ -27,6 +27,8 @@ option(AVX "Use AVX implementation" OFF )
option(RVV "Use RISCV vector extension" OFF )
option(ZVFH "Use RISCV half-precision floating-point vector extension" OFF)

option(RVV_ASM "Use RISCV vector extension in assembly implementation" OFF)

add_compile_options(-O1 -Wall -Wextra)

add_library(target-x86 INTERFACE)
Expand All @@ -46,7 +48,7 @@ target_compile_definitions(target-riscv INTERFACE RISCV)

set(WRAP_SPECS_FILE "htif_wrap.specs")
set(SPECS_FILE "htif_nano.specs")
set(SPEC_FLAGS -specs=${SPECS_FILE} -specs=${WRAP_SPECS_FILE})
# set(SPEC_FLAGS -specs=${SPECS_FILE} -specs=${WRAP_SPECS_FILE})

set(MARCH "rv64gc")
set(MABI "lp64d")
Expand All @@ -55,18 +57,22 @@ set(MCMODEL "medany")
# generate march flags
if (RVV)
list(APPEND MARCH "v")
list(APPEND MARCH "_zicntr")
# list(APPEND MARCH "_zicntr")

if (ZVFH)
list(APPEND MARCH "_zfh")
list(APPEND MARCH "_zvfh")
endif()
endif()
if (RVV_ASM)
list(APPEND MARCH "v")
endif()

list(JOIN MARCH "" MARCH)

if (NOT DEFINED LINKER_SCRIPT)
set(LINKER_SCRIPT ${CMAKE_SOURCE_DIR}/toolchain/htif.ld)
# set(LINKER_SCRIPT ${CMAKE_SOURCE_DIR}/toolchain/k230.ld)
endif()

target_compile_options(target-riscv INTERFACE -fno-common -fno-builtin-printf)
Expand Down
13 changes: 13 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,19 @@ cmake --build ./build/ --target all
spike --extension=gemmini --misaligned ./build/tests/tests.elf
```

### Building for K230 board

first, we clean any previous builds

```bash
rm -rf ./build/
```

```bash
cmake . -D CMAKE_TOOLCHAIN_FILE=./k230-gcc.cmake -S ./ -B ./build/ -G "Unix Makefiles" -D CMAKE_BUILD_TYPE=Debug -D RVV_ASM=ON
cmake --build ./build/ --target all
```

### Cleaning build files

```
Expand Down
19 changes: 19 additions & 0 deletions k230-gcc.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@

# set the RISCV option to ON
option(RISCV "Build for RISC-V" ON)

# CMake toolchain definition for RISC-V GCC toolchain
set(CMAKE_SYSTEM_NAME "Linux" CACHE STRING "")
set(CMAKE_SYSTEM_PROCESSOR "k230" CACHE STRING "")

set(TOOLCHAIN_PREFIX "riscv64-unknown-linux-musl-")

set(CMAKE_C_COMPILER "${TOOLCHAIN_PREFIX}gcc")
set(CMAKE_ASM_COMPILER "${TOOLCHAIN_PREFIX}gcc")
set(CMAKE_CXX_COMPILER "${TOOLCHAIN_PREFIX}g++")
set(CMAKE_AR "${TOOLCHAIN_PREFIX}ar")
set(CMAKE_LINKER "${TOOLCHAIN_PREFIX}ld")
set(CMAKE_OBJCOPY "${TOOLCHAIN_PREFIX}objcopy")
set(CMAKE_SIZE "${TOOLCHAIN_PREFIX}size")
set(CMAKE_STRIP "${TOOLCHAIN_PREFIX}ld")

142 changes: 142 additions & 0 deletions nn/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
cmake_minimum_required(VERSION 3.10)

set(cpu_impl
./impl/cpu/abs.c
./impl/cpu/acc.c
./impl/cpu/acc1.c
./impl/cpu/add.c
./impl/cpu/add1.c
./impl/cpu/div.c
./impl/cpu/dot.c
./impl/cpu/fill.c
./impl/cpu/max.c
./impl/cpu/maximum.c
./impl/cpu/maximum1.c
./impl/cpu/min.c
./impl/cpu/minimum.c
./impl/cpu/minimum1.c
./impl/cpu/mul.c
./impl/cpu/mul1.c
./impl/cpu/neg.c
./impl/cpu/norm.c
./impl/cpu/rms_norm.c
./impl/cpu/sgn.c
./impl/cpu/softmax.c
./impl/cpu/sqr.c
./impl/cpu/sqrt.c
./impl/cpu/sub.c
./impl/cpu/sum.c
./impl/cpu/transpose.c
)


if (AVX)
message(STATUS "Using AVX implementation")
add_compile_definitions(AVX)
endif ()

if (RVV)
message(STATUS "Using RVV implementation")
add_compile_definitions(RVV)

if (RISCV_ZVFH)
message(STATUS "Using Zvfh extension")
add_compile_definitions(RISCV_ZVFH)
endif ()

set(rvv_impl
./impl/rvv/abs.c
./impl/rvv/acc.c
./impl/rvv/acc1.c
./impl/rvv/add.c
./impl/rvv/add1.c
./impl/rvv/div.c
./impl/rvv/dot.c
./impl/rvv/max.c
./impl/rvv/maximum.c
./impl/rvv/maximum1.c
./impl/rvv/min.c
./impl/rvv/minimum.c
./impl/rvv/minimum1.c
./impl/rvv/mul.c
./impl/rvv/mul1.c
./impl/rvv/neg.c
./impl/rvv/rms_norm.c
./impl/rvv/sub.c
./impl/rvv/transpose.c
)
endif ()

if (RVV_ASM)
message(STATUS "Using RVV assembly implementation")

set(rvv_impl
./impl/rvv/abs.S
./impl/rvv/add.S
./impl/rvv/dot.S
)
endif ()

if (GEMMINI)
message(STATUS "Using Gemmini implementation")
add_compile_definitions(GEMMINI)

set(gemmini_impl
impl/gemmini/mm.c
)
endif ()


add_library(nn
./functional/nn_tensor_creation.c
./functional/nn_print.c
./functional/nn_abs.c
./functional/nn_add.c
./functional/nn_batch_norm2d.c
./functional/nn_conv2d.c
./functional/nn_clip.c
./functional/nn_copy.c
./functional/nn_div.c
./functional/nn_elu.c
./functional/nn_fill.c
./functional/nn_interpolate.c
./functional/nn_layer_norm.c
./functional/nn_linear.c
./functional/nn_matmul.c
./functional/nn_mm.c
./functional/nn_norm.c
./functional/nn_max.c
./functional/nn_maximum.c
./functional/nn_max_pool2d.c
./functional/nn_min.c
./functional/nn_minimum.c
./functional/nn_mul.c
./functional/nn_mv.c
./functional/nn_neg.c
./functional/nn_relu.c
./functional/nn_relu6.c
./functional/nn_rms_norm.c
./functional/nn_softmax.c
./functional/nn_silu.c
./functional/nn_sub.c
./functional/nn_sum.c
./functional/nn_transpose.c

${rvv_impl}
${gemmini_impl}
${cpu_impl}
)

target_include_directories(nn PUBLIC ./)

if (X86)
message(STATUS "nn: Building for x86")
target_link_libraries(nn target-x86)

elseif (RISCV)
message(STATUS "nn: Building for RISC-V")
target_link_libraries(nn target-riscv)
endif ()


target_link_libraries(nn m)
18 changes: 18 additions & 0 deletions src/ops/rvv/abs.S
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@

.globl NN__abs_f32
NN__abs_f32:
beqz a0,__abs_f32_exit
slli a4,a4,0x2
slli a2,a2,0x2
__abs_f32_loop:
vsetvli a5,a0,e32,m1,ta,ma
vlse32.v v24,(a3),a4
vfabs.v v24,v24
vsse32.v v24,(a1),a2
slli a6,a5,0x2
add a3,a3,a6
add a1,a1,a6
sub a0,a0,a5
bnez a0,__abs_f32_loop
__abs_f32_exit:
ret
21 changes: 21 additions & 0 deletions src/ops/rvv/add.S
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@

.globl NN__add_f32
NN__add_f32:
beqz a0,__add_f32_exit
slli a4,a4,0x2
slli a6,a6,0x2
slli a2,a2,0x2
__add_f32_loop:
vsetvli a7,a0,e32,m1,ta,ma
vlse32.v v24,(a3),a4
vlse32.v v25,(a5),a6
vfadd.vv v24,v24,v25
vsse32.v v24,(a1),a2
slli t1,a7,0x2
add a3,a3,t1
add a5,a5,t1
add a1,a1,t1
sub a0,a0,a7
bnez a0,__add_f32_loop
__add_f32_exit:
ret
19 changes: 19 additions & 0 deletions src/ops/rvv/add1.S
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@

.globl NN__add1_f32
NN__add1_f32:
beqz a0,__add1_f32_exit
slli a4,a4,0x2
slli a2,a2,0x2
__add1_f32_loop:
vsetvli a5,a0,e32,m1,ta,ma
vlse32.v v24,(a3),a4
vfmv.v.f v25,fa0
vfadd.vv v24,v24,v25
vsse32.v v24,(a1),a2
slli a6,a5,0x2
add a3,a3,a6
add a1,a1,a6
sub a0,a0,a5
bnez a0,__add1_f32_loop
__add1_f32_exit:
ret
25 changes: 25 additions & 0 deletions src/ops/rvv/dot.S
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@

.globl NN__dot_f32
NN__dot_f32:
vsetvli t1,zero,e32,m1,ta,ma
vmv.v.i v27,0
vmv1r.v v24,v27
beqz a0,__dot_f32_exit
slli a3,a3,0x2
slli a5,a5,0x2
__dot_f32_loop:
vsetvli a6,a0,e32,m1,ta,ma
vlse32.v v26,(a2),a3
vlse32.v v25,(a4),a5
vfmacc.vv v24,v26,v25
slli a7,a6,0x2
add a2,a2,a7
add a4,a4,a7
sub a0,a0,a6
bnez a0,__dot_f32_loop
vsetvli t1,zero,e32,m1,ta,ma
__dot_f32_exit:
vfredusum.vs v24,v24,v27
vfmv.f.s fa5,v24
fsw fa5,0(a1)
ret
Loading

0 comments on commit a784006

Please sign in to comment.