[Dev] Refactor scripts based on our new directory structure (#69)
* chore: Update support matrix in README

* Move bitblas package to root

* Remove unused code files

* Create soft link for tvm

* Create soft link for tvm

* Update softlink paths for tvm in setup.py

* Refactor import statements to use relative paths

* Fix linear test

* Move bitblas package to root

* Move bitblas package to root
LeiWang1999 committed Jul 4, 2024
1 parent 33869d4 commit f4e15a5
Showing 45 changed files with 181 additions and 423 deletions.
14 changes: 8 additions & 6 deletions bitblas/__init__.py
@@ -5,17 +5,18 @@

 # installing tvm
 install_tvm_path = os.path.join(
-    os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tvm", "python")
+    os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tvm")
 if os.path.exists(install_tvm_path) and install_tvm_path not in sys.path:
-    os.environ["PYTHONPATH"] = install_tvm_path + ":" + os.environ.get("PYTHONPATH", "")
-    sys.path.insert(0, install_tvm_path)
+    os.environ["PYTHONPATH"] = install_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "")
+    sys.path.insert(0, install_tvm_path + "/python")

 develop_tvm_path = os.path.join(
-    os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tvm", "python")
+    os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tvm")
 if os.path.exists(develop_tvm_path) and develop_tvm_path not in sys.path:
-    os.environ["PYTHONPATH"] = develop_tvm_path + ":" + os.environ.get("PYTHONPATH", "")
-    sys.path.insert(0, develop_tvm_path)
+    os.environ["PYTHONPATH"] = develop_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "")
+    sys.path.insert(0, develop_tvm_path + "/python")

+import tvm as tvm  # noqa: E402
 from . import gpu  # noqa: F401
 from .base import (
     TileDevice,  # noqa: F401
@@ -30,6 +31,7 @@
     try_inline_contiguous_spatial,  # noqa: F401
 )

+
 from . import testing  # noqa: F401
 from .utils import auto_detect_nvidia_target  # noqa: F401
 from .ops.general_matmul import MatmulConfig, Matmul  # noqa: F401
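Taken together, the two blocks above now point at the 3rdparty/tvm checkout itself and append the "/python" subdirectory when extending sys.path and PYTHONPATH, so the vendored TVM wins over any system-wide install. A minimal standalone sketch of the pattern (simplified from the diff; the PYTHONPATH export for subprocesses is omitted):

import os
import sys

# Prefer the vendored TVM under 3rdparty/ over any system-wide install.
_tvm_root = os.path.join(os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tvm")
if os.path.exists(_tvm_root):
    sys.path.insert(0, os.path.join(_tvm_root, "python"))

import tvm  # resolves to the vendored copy when the checkout exists

# Downstream modules write `from bitblas import tvm`, re-exporting this object
# so the path setup above is guaranteed to have run before TVM is imported.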
2 changes: 1 addition & 1 deletion bitblas/base/roller/arch/cpu.py
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

-import tvm
+from bitblas import tvm
from tvm.target import Target
from .arch_base import TileDevice

2 changes: 1 addition & 1 deletion bitblas/base/roller/arch/cuda.py
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

-import tvm
+from bitblas import tvm
from tvm.target import Target
from .arch_base import TileDevice
from typing import List, Dict, Union
2 changes: 1 addition & 1 deletion bitblas/base/roller/node.py
@@ -2,7 +2,7 @@
# Licensed under the MIT License.
"""PrimFunc Wrapper and Block information Analaysis"""

-import tvm
+from bitblas import tvm
from tvm import tir
from tvm.tir import IterVar, PrimFunc
from typing import Any, Dict, List, Tuple, Optional
2 changes: 1 addition & 1 deletion bitblas/base/roller/policy/default.py
@@ -7,7 +7,7 @@
from typing import Iterable, Dict, List, Optional

import numpy as np
-import tvm
+from bitblas import tvm

from ..arch import TileDevice
from ..bestfit import BestFit
2 changes: 1 addition & 1 deletion bitblas/base/roller/policy/tensorcore.py
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Policy for tensorcore schedule"""
-import tvm
+from bitblas import tvm
from typing import Dict, List, Tuple, Optional
import numpy as np

2 changes: 1 addition & 1 deletion bitblas/base/transform.py
@@ -9,7 +9,7 @@
import shutil
import tempfile
import os.path as osp
-import tvm
+from bitblas import tvm
from tvm import tir
from tvm import meta_schedule as ms
from tvm.ir import IRModule
2 changes: 1 addition & 1 deletion bitblas/base/utils.py
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

-import tvm
+from bitblas import tvm
import os
from tvm.contrib.popen_pool import PopenPoolExecutor, StatusKind
from concurrent.futures import ThreadPoolExecutor, as_completed
2 changes: 1 addition & 1 deletion bitblas/cache/operator.py
@@ -8,7 +8,7 @@
import tempfile
from hashlib import sha256
import shutil
-import tvm
+from bitblas import tvm
from tvm.contrib.tar import tar
import logging

2 changes: 1 addition & 1 deletion bitblas/gpu/intrin/lop3.py
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-import tvm
+from bitblas import tvm
from tvm.tir.function import TensorIntrin
from tvm.script import tir as T
from typing import Dict, Literal
2 changes: 1 addition & 1 deletion bitblas/gpu/rmsnorm.py
@@ -21,7 +21,7 @@
# pylint: disable=missing-docstring
"""A RMS norm schedule rule for GPU operators."""

-import tvm
+from bitblas import tvm
from tvm import tir
from tvm.tir import Block, BufferStore
from tvm.tir.expr import Cast, BufferLoad, Call
21 changes: 12 additions & 9 deletions bitblas/module/__init__.py
@@ -231,21 +231,24 @@ def warmup(self, topk=20):
 def forward(self, A, output=None):
     if A.dtype != torch.float16:
         A = A.half()
+    # can be lifted to post init.
+    self.init_params()

-    if output is None:
-        output = torch.empty(
-            A.shape[:-1] + (self.out_features,), dtype=A.dtype, device=A.device)
-    m = ctypes.c_int32(reduce(operator.mul, A.shape[:-1], 1))
     A = self.bitblas_matmul.transform_input(A)
     stream = torch.cuda.current_stream()

     A_void = ctypes.c_void_p(A.data_ptr())
     stream_handle = ctypes.c_void_p(stream.cuda_stream)
-    # can be lifted to post init.
-    self.init_params()
+    args = [A_void, *self.q_params]
+    if output is None:
+        output = torch.empty(
+            A.shape[:-1] + (self.out_features,), dtype=A.dtype, device=A.device)
+    args.append(ctypes.c_void_p(output.data_ptr()))
+    if self.bitblas_matmul.dynamic_range is not None:
+        m = reduce(operator.mul, A.shape[:-1], 1)
+        args.append(m)
+    args.append(stream_handle)
-    # m is the product of the last n - 1 dimensions of A
-    self.bitblas_matmul.lib.call(A_void, *self.q_params, ctypes.c_void_p(output.data_ptr()), m,
-                                 stream_handle)
+    self.bitblas_matmul.lib.call(*args)

     return output

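The rewritten forward() assembles the ctypes argument list incrementally, so the flattened batch dimension m is passed only when the kernel was built with a dynamic range, instead of unconditionally as before. A hedged usage sketch (assuming the class here is the Linear module in bitblas.module and a CUDA device with a compiled kernel is available):

import torch
from bitblas.module import Linear  # assumed export; construction options simplified

layer = Linear(in_features=1024, out_features=1024, bias=False)
x = torch.rand(4, 1024, dtype=torch.float16, device="cuda")
# forward() builds [A_ptr, *q_params, out_ptr, (m,) stream] and dispatches lib.call(*args)
y = layer(x)
print(y.shape)  # torch.Size([4, 1024])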
7 changes: 5 additions & 2 deletions bitblas/ops/general_matmul.py
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-import tvm
+from bitblas import tvm
from tvm.target import Target
import operator
from functools import reduce
@@ -89,7 +89,10 @@ def __legalize_propagate(self, propagate):
 def __initialize_propagate(self, propagate_a: Optional[TransformKind],
                            propagate_b: Optional[TransformKind]):
     MICRO_KERNEL_SIZE = 16
-    if (isinstance(self.M, int) and (self.M % MICRO_KERNEL_SIZE) == 0 and
+    if propagate_b is not None and propagate_b == TransformKind.NonTransform:
+        # Currently we do not support propagate_a when propagate_b is not transformed.
+        object.__setattr__(self, "propagate_a", TransformKind.NonTransform)
+    elif (isinstance(self.M, int) and (self.M % MICRO_KERNEL_SIZE) == 0 and
             (self.K % MICRO_KERNEL_SIZE) == 0):
         object.__setattr__(self, "propagate_a", TransformKind.IntraWarpTransform)
     else:
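The new first branch turns off propagate_a whenever B is left untransformed; otherwise the existing micro-kernel divisibility check decides. A standalone sketch of the decision logic (strings stand in for TransformKind, and the final fallback is an assumption since the diff truncates the else branch):

MICRO_KERNEL_SIZE = 16

def choose_propagate_a(M, K, propagate_b):
    # propagate_a is unsupported when B is not transformed.
    if propagate_b == "NonTransform":
        return "NonTransform"
    # Static M and K that tile evenly into 16x16 micro-kernels allow propagation.
    if isinstance(M, int) and M % MICRO_KERNEL_SIZE == 0 and K % MICRO_KERNEL_SIZE == 0:
        return "IntraWarpTransform"
    return "NonTransform"  # assumed fallback

assert choose_propagate_a(256, 256, "IntraWarpTransform") == "IntraWarpTransform"
assert choose_propagate_a(256, 256, "NonTransform") == "NonTransform"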
2 changes: 1 addition & 1 deletion bitblas/ops/impl/batch_matmul_dequantize_impl.py
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pre-transformed tir expression of matmul
-import tvm
+from bitblas import tvm
from tvm import te, DataType
from tvm.tir import IndexMap
from bitblas.ops.operator import TransformKind
2 changes: 1 addition & 1 deletion bitblas/ops/impl/batch_matmul_impl.py
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pre-transformed tir expression of matmul
-import tvm
+from bitblas import tvm
from tvm import te
from bitblas.ops.operator import TransformKind

2 changes: 1 addition & 1 deletion bitblas/ops/impl/convolution2d_impl.py
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pre-transformed tir expression of matmul
-import tvm
+from bitblas import tvm
from tvm import te, tir


2 changes: 1 addition & 1 deletion bitblas/ops/impl/matmul_dequantize_impl.py
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pre-transformed tir expression of matmul
-import tvm
+from bitblas import tvm
from tvm import te, DataType
from tvm.tir import IndexMap
from bitblas.ops.operator import TransformKind
2 changes: 1 addition & 1 deletion bitblas/ops/impl/matmul_dequantize_splitk_impl.py
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pre-transformed tir expression of matmul
-import tvm
+from bitblas import tvm
from tvm import te
from bitblas.quantization import (_tir_packed_int_to_int_convert, _tir_packed_to_signed_convert,
_tir_packed_to_unsigned_convert, _tir_u32_to_f4_to_f16,
2 changes: 1 addition & 1 deletion bitblas/ops/impl/matmul_impl.py
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pre-transformed tir expression of matmul
-import tvm
+from bitblas import tvm
from tvm import te
from bitblas.gpu.matmul_analysis import get_propagate_map
from bitblas.ops.operator import TransformKind
2 changes: 1 addition & 1 deletion bitblas/ops/impl/matmul_splitk_impl.py
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pre-transformed tir expression of matmul
-import tvm
+from bitblas import tvm
from tvm import te
from bitblas.ops.operator import TransformKind

2 changes: 1 addition & 1 deletion bitblas/ops/matmul.py
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-import tvm
+from bitblas import tvm
import numpy as np
from tvm.target import Target
from bitblas.utils.tensor_adapter import tvm_tensor_to_torch
2 changes: 1 addition & 1 deletion bitblas/ops/matmul_dequantize.py
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-import tvm
+from bitblas import tvm
from tvm.target import Target
from bitblas.base.roller.arch.cuda import CUDA
from typing import Any, List, Literal, Optional, Tuple, Union
2 changes: 1 addition & 1 deletion bitblas/ops/operator.py
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from abc import ABC, abstractmethod
-import tvm
+from bitblas import tvm
from tvm import IRModule
from tvm.target import Target
from tvm.tir import PrimFunc
2 changes: 1 addition & 1 deletion bitblas/quantization/quantization.py
@@ -21,7 +21,7 @@
# pylint: disable=invalid-name,missing-function-docstring,unused-variable
"""TIR computation utilities for quantization."""

-import tvm
+from bitblas import tvm
from tvm import tir


3 changes: 2 additions & 1 deletion bitblas/testing/__init__.py
@@ -5,7 +5,8 @@
import pytest
from bitblas.base import DefaultPolicy, TensorCorePolicy
from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags

+from bitblas import tvm  # pylint: disable=import-error
+from tvm.testing.utils import *

# pytest.main() wrapper to allow running single test file
def main():
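The body of main() is truncated here; a plausible implementation of such a wrapper, with the frame inspection being an assumption rather than the repository's actual code:

import inspect
import sys

import pytest

def main():
    # Run pytest on the caller's file so `python test_foo.py` executes its tests.
    test_file = inspect.getsourcefile(sys._getframe(1))
    sys.exit(pytest.main([test_file] + sys.argv[1:]))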
2 changes: 1 addition & 1 deletion bitblas/utils/tensor_adapter.py
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-import tvm
+from bitblas import tvm
from typing import Union
from enum import IntEnum
import numpy as np
2 changes: 1 addition & 1 deletion bitblas/wrapper/general.py
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-import tvm
+from bitblas import tvm
from typing import Optional, List, Dict, Union
from tvm import IRModule
from bitblas import TileDevice
2 changes: 1 addition & 1 deletion install.sh
@@ -21,6 +21,6 @@ echo "set(USE_LLVM llvm-config-10)" >> config.cmake && echo "set(USE_CUDA ON)" >
cmake .. && make -j && cd ../../..

echo "export TVM_HOME=$(pwd)/3rdparty/tvm" >> ~/.bashrc
echo "export PYTHONPATH=\$TVM_HOME/python:$(pwd)/python:\$PYTHONPATH" >> ~/.bashrc
echo "export PYTHONPATH=\$TVM_HOME/python:$(pwd):\$PYTHONPATH" >> ~/.bashrc

source ~/.bashrc
6 changes: 3 additions & 3 deletions maint/scripts/installation.sh
@@ -3,8 +3,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

-# install torch
-pip install torch==2.1.0
+# install requirements
+pip install -r requirements.txt

# install llvm
apt-get install llvm-10
@@ -21,6 +21,6 @@ echo "set(USE_LLVM llvm-config-10)" >> config.cmake && echo "set(USE_CUDA ON)" >
cmake .. && make -j && cd ../../..

echo "export TVM_HOME=$(pwd)/3rdparty/tvm" >> ~/.bashrc
echo "export PYTHONPATH=\$TVM_HOME/python:$(pwd)/python:\$PYTHONPATH" >> ~/.bashrc
echo "export PYTHONPATH=\$TVM_HOME/python:$(pwd):\$PYTHONPATH" >> ~/.bashrc

source ~/.bashrc
5 changes: 5 additions & 0 deletions setup.py
@@ -209,6 +209,8 @@ def run(self):
         build_tvm(llvm_path)
         # Continue with the standard installation process
         install.run(self)
+        # Create softlink for bitblas
+        create_softlink(tvm_path="../3rdparty/tvm/python/tvm", bitblas_path="bitblas/tvm")


class BitBLASBuilPydCommand(build_py):
@@ -222,6 +224,9 @@ def run(self):
         _, llvm_path = setup_llvm_for_tvm()
         # Build TVM
         build_tvm(llvm_path)
+        # Create softlink for bitblas
+        create_softlink(tvm_path="../3rdparty/tvm/python/tvm", bitblas_path="bitblas/tvm")

         # Copy the built TVM to the package directory
         TVM_PREBUILD_ITEMS = [
             "3rdparty/tvm/build/libtvm_runtime.so",
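create_softlink itself is not part of this diff; only its call sites and signature appear above. A plausible sketch under that assumption:

import os

def create_softlink(tvm_path, bitblas_path):
    # Link bitblas/tvm -> ../3rdparty/tvm/python/tvm so that `from bitblas import tvm`
    # resolves inside a built or installed package tree.
    if os.path.islink(bitblas_path):
        os.remove(bitblas_path)  # refresh a stale link
    if not os.path.exists(bitblas_path):
        os.symlink(tvm_path, bitblas_path, target_is_directory=True)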