Skip to content

Commit ae355c0

Browse files
authored
fix for auto_parallel (#10370)
1 parent d5275b7 commit ae355c0

File tree

1 file changed

+2
-8
lines changed

1 file changed

+2
-8
lines changed

paddlenlp/transformers/llama/fusion_ops.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -136,14 +136,8 @@ def rms_norm_fused(x_in, w, eps, use_fast_ln=False):
136136
fast_ln = try_import("fast_ln")
137137
return fast_ln.fast_rms_norm(x_in, w, eps)[0]
138138
else:
139-
try:
140-
from paddle.incubate.nn.functional import fused_rms_norm
141-
142-
return fused_rms_norm(x=x_in, norm_weight=w, norm_bias=None, epsilon=eps, begin_norm_axis=2)[0]
143-
except ImportError:
144-
fused_ln = try_import("fused_ln")
145-
146-
return fused_ln.fused_rms_norm(x_in, w, eps)[0]
139+
fused_ln = try_import("fused_ln")
140+
return fused_ln.fused_rms_norm(x_in, w, eps)[0]
147141

148142

149143
def fusion_rms_norm(hidden_states, weight, variance_epsilon, use_fast_ln=False):

0 commit comments

Comments (0)