
formula 22 in DeepSeek V3 technical report #238

Open
guoyejun opened this issue Jan 7, 2025 · 3 comments

guoyejun commented Jan 7, 2025

Thanks for the great model.

I have a question about formula 22 below; could you help? Thanks.

[Screenshot of Eq. (22) from the DeepSeek-V3 technical report]

Suppose k=1, i.e., MTP module 1 in the red circle of the figure below, and T=4 in the example. Then T-k = 4-1 = 3, so h(1:T-k) is h(1:3).

My question is: why is it 1:3 rather than 1:4? From the figure below, MTP module 1 ultimately has 4 outputs. Is it a typo for h(1:T)?

[Figure: MTP modules from the DeepSeek-V3 technical report]

chuhac commented Jan 10, 2025

I have to respectfully disagree with your claim; let me explain my view.

The $h$ here denotes a [b, s, h] tensor, and the slicing indices apply to the s dimension. If you look carefully at the CausalLM loss computation in a code repo such as Megatron-LM or Hugging Face transformers, the implementation looks like this (pseudo code):

loss = LLM(input_ids=input_ids, labels=labels).loss

Inside the loss function, the labels are shifted so that only the tokens before a given position are used to predict it. To simplify the reasoning: for plain next-token prediction, your input_ids and labels should be the same.
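
For reference, here is a minimal sketch of that internal shift, in the style of a Hugging Face causal-LM loss (names and exact code are illustrative, not the actual library source):

```python
import torch.nn.functional as F

def causal_lm_loss(logits, labels):
    # logits: [b, s, vocab]; labels: [b, s], identical to input_ids.
    # Shift so that the logits at position i are scored against token i+1.
    shift_logits = logits[:, :-1, :]   # predictions for positions 1 .. s-1
    shift_labels = labels[:, 1:]       # targets are the next tokens
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
    )
```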

If you agree with this, let's proceed to the MTP scenario. In a batch of tokens you can only use the very same input_ids to predict the labels (that is really all you can do), so for Next$^2$ token prediction you should do

loss = TRM(input_ids=input_ids[:-1], labels=labels[1:])

In this implementation, the first token of input_ids is now aligned with the second token of labels. Recalling the inner shift described above, you are actually using the first hidden state to predict the third token.

And for Next$^k$ token prediction you need to shift even further, which yields the formula

$h_{1:T-k}^k = TRM_k(h_{1:T-k}^{\prime k})$
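
To make the index bookkeeping concrete, here is a minimal sketch of the slicing implied by this formula at depth $k$ (the names `combine_k` and `trm_k` are placeholders for whatever Eq. (21) and the k-th MTP transformer block actually do; this is not DeepSeek's code):

```python
def mtp_depth_k(h_prev, input_ids, k, combine_k, trm_k):
    # h_prev:    [b, T, d] hidden states from depth k-1  (h^{k-1}_{1:T})
    # input_ids: [b, T]    tokens t_1 .. t_T
    T = input_ids.size(1)
    # Only positions 1 .. T-k have a token k steps ahead to pair with,
    # so the sequence dimension shrinks from T to T-k at this depth.
    h_slice  = h_prev[:, : T - k, :]         # h^{k-1}_{1:T-k}
    tok_tail = input_ids[:, k:]              # t_{1+k}, ..., t_T  (length T-k)
    h_prime  = combine_k(h_slice, tok_tail)  # h'^{k}_{1:T-k}, per Eq. (21)
    h_k      = trm_k(h_prime)                # Eq. (22): h^{k}_{1:T-k} = TRM_k(h'^{k}_{1:T-k})
    # Downstream, position i of h_k is used to predict token t_{i+k+1}.
    return h_k
```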

@yejunguo

I don't see why we need to refer to the loss computation code; it is just a simple formula for the input/output of a transformer block.

For simplicity, let's use k=1 and T=4, to align with the figure below.

From formula 21, we get four outputs, h1', h2', h3' and h4', whose corresponding inputs are t2, t3, t4 and t5 at the bottom of the figure.
[Figure: MTP modules from the DeepSeek-V3 technical report]

With formula 22, the expected outputs are h1, h2, h3 and h4, with:
h1 = f(h1')
h2 = f(h1', h2')
h3 = f(h1', h2', h3')
h4 = f(h1', h2', h3', h4')

My concern is that h(1:3) in formula 22 does not capture the meaning above.

[Screenshot of Eq. (22) from the DeepSeek-V3 technical report]
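
For context, the expectation above is just the usual behavior of a causally masked transformer block: n input vectors in, n output vectors out, with output i attending only to inputs 1..i. A minimal PyTorch sketch of that behavior (illustrative only, unrelated to the paper's code):

```python
import torch
import torch.nn as nn

d = 16
trm = nn.TransformerEncoderLayer(d_model=d, nhead=4, batch_first=True)
h_prime = torch.randn(1, 4, d)                             # h1', h2', h3', h4'
mask = nn.Transformer.generate_square_subsequent_mask(4)   # causal mask
h = trm(h_prime, src_mask=mask)                            # h1 .. h4; h_i sees only h1'..hi'
print(h.shape)                                             # torch.Size([1, 4, 16])
```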

chuhac commented Jan 10, 2025

I understand your claim.

The main point of my explanation is that your input_ids and labels get shorter as the depth k increases. Specifically, taking the paper's figure as an example, you should use T=7 in my view. In that case, the hidden states on the left should each be written as follows:

(shifted)
depth=0, input: hidden states(t1, ... , t6)  labels: t2, ..., t7 
depth=1, input: hidden states(t1, ... , t5) + embedding from(t2, ..., t6)  labels: t3, ..., t7
depth=2, input: hidden states(t1, ... , t4) + embedding from(t3, ..., t6) labels: t4, ..., t7

In the figure, the tokens on the left and right sides are omitted; what really happens is what I described above. Formula 22 is describing exactly this: you need to shift (and shorten) the hidden states as $k$ increases, e.g., hidden states(t1, ..., t4) when depth=2.
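
A quick index-arithmetic sketch that reproduces the table above for T=7 (purely illustrative, not model code):

```python
T = 7
tok = [f"t{i}" for i in range(1, T + 1)]      # ["t1", ..., "t7"]

for k in range(3):                            # MTP depth k = 0, 1, 2
    hidden = tok[: T - 1 - k]                 # hidden states fed at this depth
    emb    = tok[k : T - 1] if k > 0 else []  # embeddings of the shifted tokens (k >= 1)
    labels = tok[k + 1 :]                     # prediction targets at this depth
    print(f"depth={k}: hidden={hidden} + emb={emb} -> labels={labels}")
```

Running it shows the windows shrinking by one token per extra depth, which is exactly the 1:T-k range in formula 22.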
