Skip to content

Commit 10aad54

Browse files
author
Mark-ZhouWX
committed
workaround of layernorm2d
1 parent 69dd016 commit 10aad54

File tree

1 file changed

+9
-3
lines changed
  • official/cv/segment-anything/segment_anything/modeling

1 file changed

+9
-3
lines changed

official/cv/segment-anything/segment_anything/modeling/common.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,14 @@ def __init__(self, num_channels: int, epsilon: float = 1e-6) -> None:
3030
self.eps = epsilon
3131

3232
def construct(self, x: ms.Tensor) -> ms.Tensor:
33-
u = x.mean(1, keep_dims=True)
34-
s = (x - u).pow(2).mean(1, keep_dims=True)
35-
x = (x - u) / ops.sqrt(s + self.eps)
33+
bs, c, h, w = x.shape
34+
x = x.reshape(bs, c, -1).swapaxes(1, 2) # (bs, c, h, w) -> (bs, hw, c)
35+
36+
u = x.mean(-1, keep_dims=True) # (bs, hw, 1)
37+
s = (x - u).pow(2).mean(-1, keep_dims=True) # (bs, hw, 1)
38+
x = (x - u) / ops.sqrt(s + self.eps) # (bs, hw, c)
39+
40+
x = x.swapaxes(1, 2).reshape(bs, c, h, w)
41+
3642
x = self.weight[:, None, None] * x + self.bias[:, None, None]
3743
return x

0 commit comments

Comments
 (0)