
perf: optimize TokenNumFilter with batch tokenization#929

Open
JohnGiorgi wants to merge 1 commit into datajuicer:main from JohnGiorgi:perf/token-num-filter-batch-tokenization

Conversation

@JohnGiorgi (Contributor) commented Mar 3, 2026

Summary

Switch TokenNumFilter from per-sample tokenizer.tokenize() to batched tokenizer(texts, add_special_tokens=False), and enable _batched_op = True.

Problem

TokenNumFilter processes each sample individually via compute_stats_single, calling tokenizer.tokenize(text) through the get_words_from_document wrapper. This has three sources of overhead:

  • No batching: _batched_op is not set, so the framework processes one sample at a time (batch_size=1)
  • String conversion: .tokenize() returns subword strings (e.g. ['▁Hello', '▁world']); the string conversion is wasted when we only need a count
  • Per-sample overhead: get_model() lookup + get_words_from_document() wrapper called per sample

Fix

  • Set _batched_op = True to enable framework batching (batches of 1000, matching DEFAULT_BATCH_SIZE)
  • Replace compute_stats_single with compute_stats_batched: collect untokenized texts, call tokenizer(texts, add_special_tokens=False) once per batch, and store len(input_ids) per sample
  • Replace process_single with process_batched: a list comprehension over batch stats
  • Remove unused get_words_from_document import

Follows the established _batched_op pattern used by other filters (e.g. WordsNumFilter, TextLengthFilter). add_special_tokens=False is critical for backward compatibility: tokenizer.tokenize() does not add special tokens, so we match that behavior. Token counts verified identical on 100K synthetic samples.
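
The batched counting in the Fix above can be sketched as a minimal, self-contained snippet. The `FakeTokenizer` below is a stand-in for a real Hugging Face tokenizer (whitespace splitting instead of subword tokenization) so the example runs without dependencies; it is an illustration, not the actual Data-Juicer implementation:

```python
def count_tokens_batched(texts, tokenizer):
    """Return one token count per text using a single batched tokenizer call."""
    # One call for the whole batch; add_special_tokens=False matches the
    # behavior of the old per-sample tokenizer.tokenize().
    encodings = tokenizer(texts, add_special_tokens=False)
    # Only the lengths are needed; no conversion to token strings.
    return [len(ids) for ids in encodings["input_ids"]]

class FakeTokenizer:
    """Stand-in for a Hugging Face tokenizer; splits on whitespace."""
    def __call__(self, texts, add_special_tokens=True):
        return {"input_ids": [list(range(len(t.split()))) for t in texts]}

counts = count_tokens_batched(["hello world", "one two three"], FakeTokenizer())
# counts == [2, 3]
```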

Impact

~1.15x throughput on 100K synthetic paragraph-length samples (1,350 → 1,555 samples/sec). Not a massive gain, but it adds up at scale: for example, this is projected to save ~49min on a dataset of 500_000K examples.
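
The projection can be sanity-checked with simple rate arithmetic (rates taken from the benchmark above; actual savings vary with hardware and text length):

```python
# Time saved on n samples = n/old_rate - n/new_rate.
old_rate, new_rate = 1350.0, 1555.0  # measured samples/sec, before and after
n = 100_000
saved_sec = n / old_rate - n / new_rate
print(f"~{saved_sec:.1f}s saved per {n:,} samples")  # scales linearly with n
```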

Test plan

  • pytest tests/ops/filter/test_token_num_filter.py
  • Token count equivalence verified on 100K synthetic samples (old tokenizer.tokenize() vs new tokenizer(texts, add_special_tokens=False)['input_ids'])
  • Pre-commit hooks pass
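
The token-count equivalence check can be written as a small harness over any tokenizer exposing both `.tokenize()` and the batched `__call__` API. A trivial stub stands in for the real tokenizer here; substituting one loaded via `AutoTokenizer.from_pretrained(...)` reproduces the actual verification:

```python
def check_equivalence(tokenizer, texts):
    """Compare per-sample .tokenize() counts with batched __call__ counts."""
    old = [len(tokenizer.tokenize(t)) for t in texts]
    new = [len(ids) for ids in tokenizer(texts, add_special_tokens=False)["input_ids"]]
    return old == new

class StubTokenizer:
    """Whitespace stand-in implementing both tokenizer interfaces."""
    def tokenize(self, text):
        return text.split()
    def __call__(self, texts, add_special_tokens=True):
        return {"input_ids": [[0] * len(t.split()) for t in texts]}

assert check_equivalence(StubTokenizer(), ["a b", "c d e", ""])
```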

Switch TokenNumFilter from per-sample tokenizer.tokenize() to batched
tokenizer(texts, add_special_tokens=False), and enable _batched_op for
framework-level batching. This yields ~1.15x throughput improvement on
100K samples (1,350 → 1,555 samples/sec) by:

- Amortizing get_model() and Python call overhead across batches of 1000
- Returning int IDs directly instead of converting to string tokens
- Enabling the framework's batched map path via _batched_op = True

Token counts are identical between old and new approaches, verified on
100K synthetic samples. add_special_tokens=False matches the behavior
of the previous tokenizer.tokenize() call.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@gemini-code-assist (Contributor) commented

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly optimizes the TokenNumFilter by transitioning its tokenization and processing logic from a per-sample approach to a batched one. This change leverages framework-level batching, leading to improved throughput without altering existing configurations or introducing new parameters.

Highlights

  • Batch Processing: Enabled framework-level batching for TokenNumFilter by setting _batched_op = True, improving processing efficiency.
  • Tokenization Logic: Switched from per-sample tokenizer.tokenize() to batched tokenizer(texts, add_special_tokens=False) for token count calculation.
  • Method Refactoring: Replaced compute_stats_single with compute_stats_batched and process_single with process_batched to support batch operations.
  • Code Cleanup: Removed the unused import of get_words_from_document.


Changelog
  • data_juicer/ops/filter/token_num_filter.py
    • Removed the import of get_words_from_document.
    • Added _batched_op = True to enable batch processing.
    • Replaced compute_stats_single with compute_stats_batched to handle token count calculation for multiple samples simultaneously.
    • Replaced process_single with process_batched for efficient batch-wise filtering.
Activity
  • The TokenNumFilter unit tests passed successfully.
  • Token count equivalence between the old and new methods was verified on 100K synthetic samples.
  • All pre-commit hooks (black, flake8, isort) passed.
  • The pull request was generated with Claude Code.


@gemini-code-assist bot left a comment
Code Review

This pull request effectively optimizes TokenNumFilter by introducing batched tokenization, which is a significant performance improvement. The implementation is clean and follows existing patterns in the codebase. I have one suggestion to improve the robustness of the new batched implementation by handling non-string inputs, which could be a regression from the previous version.

Comment on lines +59 to +62
```python
for idx, stat in enumerate(samples_stats):
    if StatsKeys.num_token not in stat:
        indices.append(idx)
        texts.append(samples_list[idx])
```

Severity: high

The previous implementation used get_words_from_document, which likely handled non-string or None values in the text field gracefully (e.g., by treating them as empty documents, resulting in a token count of 0). This new batched implementation directly appends samples_list[idx] to the texts list. If samples_list contains any non-string values, the tokenizer call will raise an error. To improve robustness and prevent potential regressions, it's important to handle non-string inputs. A safe way to do this is to assign a token count of 0 to non-string samples, mimicking the likely behavior of the previous implementation.

Suggested change

```python
for idx, stat in enumerate(samples_stats):
    if StatsKeys.num_token not in stat:
        indices.append(idx)
        texts.append(samples_list[idx])
```

```python
for idx, stat in enumerate(samples_stats):
    if StatsKeys.num_token not in stat:
        text = samples_list[idx]
        if isinstance(text, str):
            indices.append(idx)
            texts.append(text)
        else:
            # Treat non-string samples as having 0 tokens
            samples_stats[idx][StatsKeys.num_token] = 0
```

@JohnGiorgi (Author) commented Mar 3, 2026
I don't think this is true:

```python
def get_words_from_document(
    document,
    token_func=None,
    new_line=True,
    tab=True,
):
    """
    Get words from a document. Useful to compute ratios, like the
    stopwords ratio.

    :param document: document that need to split words.
    :param token_func: function of tokenizer, if specified, the function
        will be used for split document into different tokens.
    :param new_line: whether to use '\\n' to split words.
    :param tab: whether to use '\\t' to split words.
    :return: word list obtained from document
    """
    if token_func:
        words = token_func(document)
    else:
        words = split_on_whitespace(document, new_line, tab)
    return words
```
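
A two-line check illustrates the point: `get_words_from_document` passes the document straight to `token_func`, so a non-string sample would have raised in the old path as well (a stand-in `token_func` is used here):

```python
def token_func(document):
    return document.split()  # stand-in for tokenizer.tokenize

raised = False
try:
    token_func(None)  # what the old path effectively did for a None sample
except AttributeError:
    raised = True
print("old per-sample path raises on non-string input:", raised)
```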

@JohnGiorgi (Author) commented Mar 3, 2026

I also think there is an upstream filter for empty strings:

```python
# 2. filter out those samples with empty or None text
# TODO: optimize the filtering operation for better efficiency
logger.info(f"There are {len(dataset)} sample(s) in the original dataset.")

def non_empty_text(sample, target_keys):
    for target_key in target_keys:
        # TODO: case for CFT, in which the len(sample[target_key]) == 0
        if sample[target_key] is None:
            # we filter out the samples contains at least None column
            # since the op can not handle it now
            return False
    return True

dataset = dataset.filter(non_empty_text, num_proc=num_proc, fn_kwargs={"target_keys": text_keys})
logger.info(f"{len(dataset)} samples left after filtering empty text.")
```
