Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Whisper in new concise nn.Module API #868

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

LeshengJin
Copy link
Contributor

@LeshengJin LeshengJin commented Sep 5, 2023

The first version of TVM Whisper. Try it out with python tests/python/test_model_whisper.py. A cuda device is required.

Need this pr(apache/tvm#15670) to be merged.

@MasterJH5574 MasterJH5574 force-pushed the main branch 2 times, most recently from 24949b0 to 58be070 Compare September 22, 2023 16:55
@raj-khare
Copy link

I get the following error when i run the test file:

  File "/root/run.py", line 600, in <module>
    main()
  File "/root/run.py", line 579, in main
    model = model.jit(spec=mod_spec, target=target, device="cuda", out_format="torch", debug=True)
  File "/usr/local/lib/python3.10/dist-packages/tvm/relax/frontend/nn/core.py", line 524, in jit
    spec, vm, params = _compile(spec, device, pipeline, debug)  # pylint: disable=invalid-name
  File "/usr/local/lib/python3.10/dist-packages/tvm/relax/frontend/nn/core.py", line 513, in _compile
    relax_build(
  File "/usr/local/lib/python3.10/dist-packages/tvm/relax/vm_build.py", line 341, in build
    return _vmlink(
  File "/usr/local/lib/python3.10/dist-packages/tvm/relax/vm_build.py", line 247, in _vmlink
    lib = tvm.build(
  File "/usr/local/lib/python3.10/dist-packages/tvm/driver/build_module.py", line 294, in build
    rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
  File "tvm/_ffi/_cython/./packed_func.pxi", line 332, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 263, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 252, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 182, in tvm._ffi._cy3.core.CHECK_CALL
  File "/usr/local/lib/python3.10/dist-packages/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const+0x278) [0xffff9b8bc598]
  [bt] (7) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const+0x428) [0xffff9b8bd0b8]
  [bt] (6) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const+0x278) [0xffff9b8bc598]
  [bt] (5) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const+0x1c8) [0xffff9b8baeac]
  [bt] (4) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(+0x1f13674) [0xffff9c293674]
  [bt] (3) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(+0x1f13294) [0xffff9c293294]
  [bt] (2) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(+0x1f1067c) [0xffff9c29067c]
  [bt] (1) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::runtime::detail::LogFatal::Entry::Finalize()+0x68) [0xffff9b5a36a8]
  [bt] (0) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::runtime::Backtrace[abi:cxx11]()+0x30) [0xffff9d3fd050]
  Did you forget to bind?
    Variable `B` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `A` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `matmul` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `matmul` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `matmul` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
  File "/opt/mlc-llm/3rdparty/tvm/src/tir/analysis/verify_memory.cc", line 205
RuntimeError: Memory verification failed with the following errors:
# from tvm.script import tir as T

@T.prim_func
def matmul11(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64(16), T.int64(1), T.int64(64)), "float32")):
    T.func_attr({"target": T.target({"arch": "sm_87", "host": {"keys": ["cpu"], "kind": "llvm", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "registers_per_block": 65536, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
    total_seq_len = T.int64()
    A = T.match_buffer(var_A, (T.int64(1), T.int64(16), T.int64(1), total_seq_len))
    B = T.match_buffer(var_B, (T.int64(1), T.int64(16), total_seq_len, T.int64(64)))
    for i1, i3, k in T.grid(T.int64(16), T.int64(64), total_seq_len):
        cse_var_1: T.int64 = i1 * T.int64(64) + i3
        matmul_1 = T.Buffer((T.int64(1024),), data=matmul.data)
        if k == T.int64(0):
            matmul_1[cse_var_1] = T.float32(0)
        A_1 = T.Buffer((total_seq_len * T.int64(16),), data=A.data)
        B_1 = T.Buffer((total_seq_len * T.int64(1024),), data=B.data)
        matmul_1[cse_var_1] = matmul_1[cse_var_1] + A_1[i1 * total_seq_len + k] * B_1[k * T.int64(64) + i1 * total_seq_len * T.int64(64) + i3]

@pra-dan
Copy link

pra-dan commented Aug 3, 2024

@raj-khare @LeshengJin even the mlc_chat import fails now. Is there a workaround? I think mlc_chat has been transformed into mlc_llm.

@prashant-dn
Copy link

prashant-dn commented Aug 4, 2024

I get the following error when i run the test file:

  File "/root/run.py", line 600, in <module>
    main()
  File "/root/run.py", line 579, in main
    model = model.jit(spec=mod_spec, target=target, device="cuda", out_format="torch", debug=True)
  File "/usr/local/lib/python3.10/dist-packages/tvm/relax/frontend/nn/core.py", line 524, in jit
    spec, vm, params = _compile(spec, device, pipeline, debug)  # pylint: disable=invalid-name
  File "/usr/local/lib/python3.10/dist-packages/tvm/relax/frontend/nn/core.py", line 513, in _compile
    relax_build(
  File "/usr/local/lib/python3.10/dist-packages/tvm/relax/vm_build.py", line 341, in build
    return _vmlink(
  File "/usr/local/lib/python3.10/dist-packages/tvm/relax/vm_build.py", line 247, in _vmlink
    lib = tvm.build(
  File "/usr/local/lib/python3.10/dist-packages/tvm/driver/build_module.py", line 294, in build
    rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
  File "tvm/_ffi/_cython/./packed_func.pxi", line 332, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 263, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 252, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 182, in tvm._ffi._cy3.core.CHECK_CALL
  File "/usr/local/lib/python3.10/dist-packages/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const+0x278) [0xffff9b8bc598]
  [bt] (7) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const+0x428) [0xffff9b8bd0b8]
  [bt] (6) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const+0x278) [0xffff9b8bc598]
  [bt] (5) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const+0x1c8) [0xffff9b8baeac]
  [bt] (4) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(+0x1f13674) [0xffff9c293674]
  [bt] (3) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(+0x1f13294) [0xffff9c293294]
  [bt] (2) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(+0x1f1067c) [0xffff9c29067c]
  [bt] (1) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::runtime::detail::LogFatal::Entry::Finalize()+0x68) [0xffff9b5a36a8]
  [bt] (0) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::runtime::Backtrace[abi:cxx11]()+0x30) [0xffff9d3fd050]
  Did you forget to bind?
    Variable `B` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `A` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `matmul` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `matmul` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `matmul` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
  File "/opt/mlc-llm/3rdparty/tvm/src/tir/analysis/verify_memory.cc", line 205
RuntimeError: Memory verification failed with the following errors:
# from tvm.script import tir as T

@T.prim_func
def matmul11(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64(16), T.int64(1), T.int64(64)), "float32")):
    T.func_attr({"target": T.target({"arch": "sm_87", "host": {"keys": ["cpu"], "kind": "llvm", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "registers_per_block": 65536, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
    total_seq_len = T.int64()
    A = T.match_buffer(var_A, (T.int64(1), T.int64(16), T.int64(1), total_seq_len))
    B = T.match_buffer(var_B, (T.int64(1), T.int64(16), total_seq_len, T.int64(64)))
    for i1, i3, k in T.grid(T.int64(16), T.int64(64), total_seq_len):
        cse_var_1: T.int64 = i1 * T.int64(64) + i3
        matmul_1 = T.Buffer((T.int64(1024),), data=matmul.data)
        if k == T.int64(0):
            matmul_1[cse_var_1] = T.float32(0)
        A_1 = T.Buffer((total_seq_len * T.int64(16),), data=A.data)
        B_1 = T.Buffer((total_seq_len * T.int64(1024),), data=B.data)
        matmul_1[cse_var_1] = matmul_1[cse_var_1] + A_1[i1 * total_seq_len + k] * B_1[k * T.int64(64) + i1 * total_seq_len * T.int64(64) + i3]

Hey @raj-khare I am also stuck at this error. Did you pass through?

@prashant-dn
Copy link

After replicating the env to the max possibility, I am stuck on this

Traceback (most recent call last):
  File "tests/python/test_model_whisper.py", line 176, in <module>
    main()
  File "tests/python/test_model_whisper.py", line 154, in main
    model = model.jit(spec=mod_spec, target=target, device="cuda", out_format="torch", debug=True)
  File "/Workspace/popo/miniconda3/envs/tvm-build-py38/lib/python3.8/site-packages/tvm-0.15.dev0-py3.8-linux-x86_64.egg/tvm/relax/frontend/nn/core.py", line 447, in jit
    relax.build(mod, target=target),
  File "/Workspace/popo/miniconda3/envs/tvm-build-py38/lib/python3.8/site-packages/tvm-0.15.dev0-py3.8-linux-x86_64.egg/tvm/relax/vm_build.py", line 327, in build
    return _vmlink(builder, target, tir_mod, ext_libs, params, system_lib=system_lib)
  File "/Workspace/popo/miniconda3/envs/tvm-build-py38/lib/python3.8/site-packages/tvm-0.15.dev0-py3.8-linux-x86_64.egg/tvm/relax/vm_build.py", line 241, in _vmlink
    lib = tvm.build(
  File "/Workspace/popo/miniconda3/envs/tvm-build-py38/lib/python3.8/site-packages/tvm-0.15.dev0-py3.8-linux-x86_64.egg/tvm/driver/build_module.py", line 281, in build
    rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
  File "tvm/_ffi/_cython/./packed_func.pxi", line 332, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 263, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 252, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 182, in tvm._ffi._cy3.core.CHECK_CALL
  File "/Workspace/popo/miniconda3/envs/tvm-build-py38/lib/python3.8/site-packages/tvm-0.15.dev0-py3.8-linux-x86_64.egg/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
  10: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)>::AssignTypedLambda<tvm::__mk_TVM22::{lambda(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)#1}>(tvm::__mk_TVM22::{lambda(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  9: tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
  8: tvm::SplitMixedModule(tvm::IRModule, tvm::Target const&, tvm::Target const&)
  7: tvm::ApplyPasses(tvm::IRModule, tvm::transform::Sequential)
  6: tvm::transform::Pass::operator()(tvm::IRModule) const
  5: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  4: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  3: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  2: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  1: _ZN3tvm7runtime13PackedFun
  0: tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::tir::transform::VerifyMemory()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::tir::transform::VerifyMemory()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  Did you forget to bind?
    Variable `A` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `T_transpose` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
  File "/home/popo/workspace/mlc/jinn/3rdparty/tvm/src/tir/analysis/verify_memory.cc", line 205
RuntimeError: Memory verification failed with the following errors:
# from tvm.script import tir as T

@T.prim_func
def transpose11(A: T.Buffer((T.int64(51865), T.int64(1024)), "float32"), T_transpose: T.Buffer((T.int64(1024), T.int64(51865)), "float32")):
    T.func_attr({"op_pattern": 2, "target": T.target({"arch": "sm_61", "host": {"keys": ["cpu"], "kind": "llvm", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "registers_per_block": 65536, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
    for ax0, ax1 in T.grid(1024, 51865):
        T_transpose_1 = T.Buffer((T.int64(53109760),), data=T_transpose.data)
        A_1 = T.Buffer((T.int64(53109760),), data=A.data)
        T_transpose_1[ax0 * 51865 + ax1] = A_1[ax1 * 1024 + ax0]

@DevopsDood
Copy link

Any traction on this? Would be cool to see whisper support

@bil-ash
Copy link

bil-ash commented Nov 16, 2024

Yeah, same thing. Would like to see whisper support.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants