embedding: adjust n_ubatch value, print error on insufficient n_batch value #6296

Merged
6 changes: 5 additions & 1 deletion examples/embedding/embedding.cpp
@@ -61,6 +61,8 @@ int main(int argc, char ** argv) {
}

params.embedding = true;
+ // For BERT models, batch size must be equal to ubatch size
mscheong01 marked this conversation as resolved.
+ params.n_ubatch = params.n_batch;

print_build_info();

@@ -114,7 +116,9 @@ int main(int argc, char ** argv) {
for (const auto & prompt : prompts) {
auto inp = ::llama_tokenize(ctx, prompt, true, false);
if (inp.size() > n_batch) {
- inp.resize(n_batch);
+ fprintf(stderr, "%s: error: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n",
+ __func__, (long long int) inp.size(), (long long int) n_batch);
Collaborator

Here you can use %ld instead of %lld; then there is no need to cast the type.

Suggested change
- __func__, (long long int) inp.size(), (long long int) n_batch);
+ __func__, inp.size(), n_batch);

Collaborator Author

Applied and rolled back due to a build failure.
IIRC, this is why I used %lld with casting in #6193, although I tried your suggestion just in case 😉.

Owner

We normally use PRIu64 / PRId64 to print 64-bit integers. Alternatively, in this case, just using %d with a cast to (int) is fine.
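
For illustration, here is a minimal sketch of the PRIu64 approach mentioned above. The helper name `format_batch_error` is hypothetical and not part of llama.cpp, and the sketch assumes the batch size is held in a `uint64_t`. The `size_t` token count is cast to `uint64_t` so that the argument type matches the format specifier on every platform; this is likely why a bare `%ld` could fail to build on some toolchains, since `size_t` is not `long` everywhere.

```cpp
#include <cinttypes>  // PRIu64
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <string>

// Hypothetical helper (not from llama.cpp): builds the "input too long"
// message with PRIu64 so the format string is correct on all platforms.
static std::string format_batch_error(const char * func, size_t n_tokens, uint64_t n_batch) {
    char buf[256];
    // size_t has no fixed width, so cast it explicitly to uint64_t.
    std::snprintf(buf, sizeof(buf),
        "%s: error: number of tokens in input line (%" PRIu64 ") exceeds batch size (%" PRIu64 "), increase batch size and re-run\n",
        func, (uint64_t) n_tokens, n_batch);
    return std::string(buf);
}
```

A caller could then write `fputs(format_batch_error(__func__, inp.size(), n_batch).c_str(), stderr);`. The `%d` alternative with an `(int)` cast is shorter and is adequate as long as the values fit in an `int`.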

+ return 1;
}
inputs.push_back(inp);
}
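
The behavior change in this hunk (reject an over-long input instead of silently truncating it) can be sketched in isolation as follows. This is only an illustration under stated assumptions: `token_list` and `inputs_fit_batch` are hypothetical names, and `int32_t` stands in for the token type used by the real code.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for one tokenized prompt.
using token_list = std::vector<int32_t>;

// Returns true if every tokenized prompt fits in a single batch of n_batch tokens.
// On failure, *bad_index (if non-null) receives the index of the first prompt
// that is too long; nothing is truncated.
static bool inputs_fit_batch(const std::vector<token_list> & inputs, uint64_t n_batch, size_t * bad_index) {
    for (size_t i = 0; i < inputs.size(); ++i) {
        if (inputs[i].size() > n_batch) {
            if (bad_index) {
                *bad_index = i;
            }
            return false;
        }
    }
    return true;
}
```

Failing early keeps the caller from computing an embedding for a truncated prompt without noticing; the caller can report the offending line and exit with a non-zero status, as the patched loop does.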