Commit d508732

fix tp
1 parent 522108c commit d508732

File tree

1 file changed: +4 -6 lines


src/turbomind/triton_backend/llama/LlamaTritonModel.cc

Lines changed: 4 additions & 6 deletions
@@ -333,12 +333,6 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
     }
     else {
         moe_param_.method = ft::MoeParam::kFused;
-        // Note: This will fail when GPUs of different SMs are mixed
-        if (weight_type_ != ft::WeightType::kINT4 && ft::getSMVersion() >= 90) {
-            // On sm90 the cuBLAS method may be faster as our grouped GEMM is not
-            // optimized for GMMA yet
-            moe_param_.method = ft::MoeParam::kNaive;
-        }
     }
 
     TM_LOG_INFO("%s", toString().c_str());
@@ -377,6 +371,10 @@ std::unique_ptr<ft::Engine<T>> LlamaTritonModel<T>::createSharedModelInstance(
                                                      shared_state_,
                                                      device_id);
 
+    // Wait for pinned buffers to be allocated for all ranks, otherwise tuning will hang
+    // due to concurrent kernel launch & cudaMallocHost
+    shared_state_->barrier->wait();
+
     engine->Start();
 
     return engine;
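For readers unfamiliar with the pattern, the added shared_state_->barrier->wait() is a rendezvous across tensor-parallel ranks: every rank must finish its cudaMallocHost allocations before any rank starts launching tuning kernels, since interleaving the two is what the commit comment identifies as the cause of the hang. Below is a minimal standalone sketch of the same idea, using C++20 std::barrier in place of TurboMind's own barrier; the world size and buffer size are invented for illustration.

// Minimal sketch (not TurboMind code): each rank thread allocates pinned
// host memory, then all ranks rendezvous on a barrier before any of them
// starts launching kernels.
// Hypothetical build line: g++ -std=c++20 sketch.cc -lcudart -lpthread
#include <barrier>
#include <thread>
#include <vector>
#include <cuda_runtime.h>

int main()
{
    const int world_size = 2;  // hypothetical TP degree; assumes 2 visible GPUs
    std::barrier<> sync(world_size);

    std::vector<std::thread> ranks;
    for (int rank = 0; rank < world_size; ++rank) {
        ranks.emplace_back([&sync, rank] {
            cudaSetDevice(rank);
            void* pinned = nullptr;
            cudaMallocHost(&pinned, 1 << 20);  // allocation phase (pinned buffer)

            sync.arrive_and_wait();  // every rank has finished cudaMallocHost here

            // Kernel-launch phase (e.g. GEMM tuning) is now safe to run
            // concurrently: no rank is still inside cudaMallocHost.
            cudaFreeHost(pinned);
        });
    }
    for (auto& t : ranks) {
        t.join();
    }
    return 0;
}

The arrive_and_wait() call separates the allocation phase from the kernel-launch phase across all rank threads, mirroring what the added wait() does just before engine->Start().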

0 commit comments
