About the Q-shift operation in the code #37

Open
kkkkkk123-ops opened this issue Sep 13, 2024 · 1 comment

@kkkkkk123-ops

In jit_func of class VRWKV_SpatialMix_V6, why is x subtracted from the result of shift_func? In _inner_forward of class VRWKV_ChannelMix, xx is computed without subtracting x.
# Class VRWKV_SpatialMix_V6
def jit_func(self, x, patch_resolution):
    # Mix x with the previous timestep to produce xk, xv, xr
    B, T, C = x.size()

    xx = self.shift_func(x, self.shift_pixel, patch_resolution=patch_resolution,
                         with_cls_token=self.with_cls_token) - x
    xxx = x + xx * self.time_maa_x  # [B, T, C]

# Class VRWKV_ChannelMix
def _inner_forward(x):
    xx = self.shift_func(x, self.shift_pixel, patch_resolution=patch_resolution,
                         with_cls_token=self.with_cls_token)

@duanduanduanyuchen
Collaborator

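Hi! The subtraction in the VRWKV6 interpolation yields a formula equivalent to the one in VRWKV, with the roles of \mu and (1 - \mu) swapped.

VRWKV6:
xx = shift(x) - x
xxx = x + \mu * xx
    = x + \mu * (shift(x) - x)
    = (1 - \mu) * x + \mu * shift(x)

VRWKV:
xx = shift(x)
xxx = \mu * x + (1 - \mu) * xx
    = \mu * x + (1 - \mu) * shift(x)

Both are linear interpolations between x and shift(x). Substituting \mu -> 1 - \mu in one form gives the other, so a learned \mu can express the same mixing either way.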
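
The two interpolation forms can be checked numerically. The snippet below is a minimal sketch in plain Python, not the repository's code: `shift` is a hypothetical one-step shift standing in for `shift_func`, and `mix_vrwkv6` / `mix_vrwkv` are illustrative names for the two formulas. It assumes a scalar \mu and a 1-D list of values in place of the [B, T, C] tensor.

```python
import random


def shift(x):
    """Hypothetical stand-in for shift_func: shift values one step right, zero-pad the front."""
    return [0.0] + x[:-1]


def mix_vrwkv6(x, mu):
    # VRWKV6 form: xx = shift(x) - x, then xxx = x + mu * xx
    xx = [s - v for s, v in zip(shift(x), x)]
    return [v + mu * d for v, d in zip(x, xx)]


def mix_vrwkv(x, mu):
    # VRWKV form: xx = shift(x), then xxx = mu * x + (1 - mu) * xx
    xx = shift(x)
    return [mu * v + (1 - mu) * s for v, s in zip(x, xx)]


random.seed(0)
x = [random.uniform(-1.0, 1.0) for _ in range(8)]
mu = 0.3

a = mix_vrwkv6(x, mu)
b = mix_vrwkv(x, 1 - mu)  # same mixing once mu is replaced by 1 - mu

# Largest element-wise gap between the two forms; expected to be at float rounding level.
max_diff = max(abs(p - q) for p, q in zip(a, b))
print("forms agree:", max_diff < 1e-12)
```

Passing the same \mu to both functions would generally give different outputs; the match relies on the \mu -> 1 - \mu substitution.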
