Skip to content

Commit

Permalink
[Misc][MoE] add Deepseek-V3 moe tuning support (#12558)
Browse files Browse the repository at this point in the history
Signed-off-by: Divakar Verma <[email protected]>
  • Loading branch information
divakar-amd authored Jan 30, 2025
1 parent e0cc5f2 commit 1c1bb0b
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion benchmarks/kernels/benchmark_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,8 @@ def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
def main(args: argparse.Namespace):
print(args)

config = AutoConfig.from_pretrained(args.model)
config = AutoConfig.from_pretrained(
args.model, trust_remote_code=args.trust_remote_code)
if config.architectures[0] == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k
Expand All @@ -461,6 +462,11 @@ def main(args: argparse.Namespace):
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] == "DeepseekV3ForCausalLM":
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
else:
# Default: Mixtral.
E = config.num_local_experts
Expand Down Expand Up @@ -538,6 +544,7 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]:
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, required=False)
parser.add_argument("--tune", action="store_true")
parser.add_argument("--trust-remote-code", action="store_true")
args = parser.parse_args()

main(args)

0 comments on commit 1c1bb0b

Please sign in to comment.