Description
Executive Summary (Conclusion)
I believe there is a logical contradiction in the `importance_sampling` function (defined in `lotus/filters/cascade.py`) when used with a small `cascade_IS_max_sample_range` (from `lotus/types.py::CascadeArgs`).
The core issue: the function's "safety net"—the uniform sampling component `(1 - is_weight)` intended to sample low-score items—is completely nullified by the `sample_range` slicing.
This makes it impossible to sample any low-score items (e.g., items with a proxy score near 0.0). As a result, the `learn_filter_cascade_thresholds` function (which relies on this sampling) is never fed any high-confidence `False` examples, making it logically impossible to learn a valid `neg_cascade_threshold` (the lower-bound threshold).
Argument & Evidence (Step-by-Step Trace)
Here is the logical trace, using the `CascadeArgs` defaults from the source and a hypothetical dataset.
1. Setup & Parameters:
- From `CascadeArgs` (`lotus/types.py`):
  - `cascade_IS_max_sample_range: int = 200`
  - `cascade_IS_weight: float = 0.9` (this implies a 90% biased + 10% uniform sampling mix)
- Runtime assumptions:
  - `len(proxy_scores) = 100_000` (a large dataset)
  - `proxy_scores` is assumed to be sorted descending (this seems implied by the logic).
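To make the trace concrete, here is a small hypothetical setup (variable names are illustrative stand-ins, not taken verbatim from lotus):

```python
import numpy as np

# Hypothetical data mirroring the trace above (illustrative names only):
rng = np.random.default_rng(0)
proxy_scores = np.sort(rng.random(100_000))[::-1]  # sorted descending
is_weight = 0.9          # plays the role of cascade_IS_weight
max_sample_range = 200   # plays the role of cascade_IS_max_sample_range
```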
2. The Blended Probability `w` (Line 13 in `importance_sampling`):

```python
w = is_weight * w / np.sum(w) + (1 - is_weight) * np.ones((len(proxy_scores))) / len(proxy_scores)
```

- At this stage, `w` is an array of 100,000 probabilities.
- Crucially, the 10% uniform component (`(1 - is_weight) * ...`) correctly assigns a non-zero probability to all items, including the low-score item at index `90_000`. This "safety net" is working as intended here.
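This property can be checked numerically with a sketch (assuming `w` starts from the raw proxy scores before blending, which is an assumption about the surrounding code, not a quote of it):

```python
import numpy as np

# Sketch of the blended probability from step 2 (illustrative, not the
# exact lotus code). Assumption: w starts as the raw proxy scores.
proxy_scores = np.linspace(1.0, 0.0, 100_000)  # sorted descending
is_weight = 0.9
w = proxy_scores.copy()
w = is_weight * w / np.sum(w) + (1 - is_weight) * np.ones(len(proxy_scores)) / len(proxy_scores)

# Every index carries non-zero mass, including the tail item at 90_000,
# and the blend is still a valid probability distribution:
assert w[90_000] > 0
assert np.isclose(w.sum(), 1.0)
```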
3. The `sample_range` Calculation (Line 15):

```python
sample_range = min(cascade_args.cascade_IS_max_sample_range, len(proxy_scores))
```

- This evaluates to `sample_range = min(200, 100_000)`.
- The result is `sample_range = 200`.
4. The Logical Flaw (Line 16):

```python
sample_w = w[:sample_range]
```

- This evaluates to `sample_w = w[:200]`.
- This line discards all calculated probabilities from index 200 to 99,999.
- The non-zero "safety net" probability assigned to the low-score item at index `90_000` (from step 2) is now thrown away.
5. The Sampling Pool (Line 18):

```python
indices = np.arange(sample_range)
```

- This evaluates to `indices = np.arange(200)`.
- The sampling pool (`indices`) contains only the top 200 high-score items.
- The `np.random.choice` on line 21 will only ever pick indices between 0 and 199.
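Steps 2–5 can be reproduced end to end with a minimal sketch (variable names and the explicit renormalization are illustrative, not lotus's exact code):

```python
import numpy as np

# Minimal reproduction of the trace: the slice happens BEFORE sampling,
# so the uniform component can never reach the tail.
proxy_scores = np.linspace(1.0, 0.0, 100_000)  # sorted descending
is_weight, max_sample_range = 0.9, 200

w = is_weight * proxy_scores / proxy_scores.sum() \
    + (1 - is_weight) * np.ones(len(proxy_scores)) / len(proxy_scores)

sample_range = min(max_sample_range, len(proxy_scores))  # 200
sample_w = w[:sample_range]        # tail probabilities discarded here
indices = np.arange(sample_range)  # pool holds only indices 0..199

rng = np.random.default_rng(0)
sampled = rng.choice(indices, size=1_000, p=sample_w / sample_w.sum(), replace=True)
assert sampled.max() < 200  # no low-score item can ever be drawn
```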
The Consequence: Why This Breaks the Filter Logic
The `learn_filter_cascade_thresholds` function (in `lotus/filters/cascade_utils.py`) needs to learn both a `pos_cascade_threshold` (upper bound) and a `neg_cascade_threshold` (lower bound).
- To learn a `neg_cascade_threshold`, the oracle (LLM) must be fed low-score samples to determine where the `False` boundary lies.
- Because the `importance_sampling` function (with these parameters) cannot provide any low-score samples, the system will only learn from high-score (likely `True`) items.
- This makes it impossible to learn a valid lower bound, breaking a core function of the cascade filter: filtering out high-confidence `False` items.
This behavior seems to directly contradict the purpose of including the `(1 - is_weight)` uniform sampling component in the first place.
Suggestion:
This implementation seems to imply that `cascade_IS_max_sample_range` should almost always equal `len(proxy_scores)` for filtering. Perhaps this parameter (with a small value like 200) is intended for a different operation (such as a JOIN) and is being misused in the filter cascade?
Alternatively, the sampling logic may need to be revised so the uniform "safety net" samples from the entire `len(proxy_scores)` range, not just the `sample_range` slice.
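Here is a sketch of that alternative (purely illustrative; the function name and signature are hypothetical, not a proposed patch to lotus's API): concentrate the biased mass on the top slice but spread the uniform mass over all items, so tail items stay reachable.

```python
import numpy as np

# Hypothetical revision: biased component over the top `sample_range`
# items, uniform component over the WHOLE dataset.
def importance_sampling_full_range(proxy_scores, is_weight=0.9,
                                   max_sample_range=200,
                                   num_samples=1_000, seed=0):
    n = len(proxy_scores)
    sample_range = min(max_sample_range, n)

    w = np.zeros(n)
    top = np.asarray(proxy_scores[:sample_range], dtype=float)
    w[:sample_range] = is_weight * top / top.sum()  # biased mass on the head
    w += (1 - is_weight) / n                        # uniform mass everywhere

    rng = np.random.default_rng(seed)
    return rng.choice(n, size=num_samples, p=w / w.sum(), replace=True)

samples = importance_sampling_full_range(np.linspace(1.0, 0.0, 100_000))
# With ~10% uniform mass and 1,000 draws, tail indices (>= 200) now appear:
assert (samples >= 200).any()
```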
Thank you for your time and for this interesting library.