A bunch of improvements for the classification skill #50
base: master
Conversation
1. Make the number of rows to print a configurable parameter (defaulting to 5). 2. Add a parameter to enable index column printing. This implementation prints the index of the original data frame if the passed data frame is derived from it.
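For illustration, a minimal sketch of what these two parameters could look like (the names `num_rows` and `include_index` are my assumptions, not the PR's actual signature):

```python
import pandas as pd

# Hypothetical sketch of the two options described above; num_rows and
# include_index are assumed names, not the PR's actual parameter names.
def print_dataframe(df: pd.DataFrame, num_rows: int = 5, include_index: bool = False) -> None:
    print(df.head(num_rows).to_string(index=include_index))

df = pd.DataFrame({"text": ["a", "b", "c"]})
derived = df[df["text"] != "b"]  # a derived frame keeps the original index labels 0 and 2
print_dataframe(derived, include_index=True)
```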
Otherwise we don't know whether it improved the accuracy or not. We should eventually introduce smarter learning strategies, e.g. simple ones like rejecting changes that make accuracy worse, or complex ones with genetic algorithms as in FunSearch.
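A minimal sketch of the simplest strategy mentioned above, assuming hypothetical `propose_rewrite` and `evaluate` helpers (neither is part of the library):

```python
# Greedy acceptance: keep a prompt rewrite only if it doesn't lower accuracy.
def improve_greedily(skill, train_df, propose_rewrite, evaluate, steps=10):
    best_accuracy = evaluate(skill, train_df)
    for _ in range(steps):
        candidate = propose_rewrite(skill)       # ask the teacher model for a rewrite
        accuracy = evaluate(candidate, train_df)
        if accuracy >= best_accuracy:            # reject changes that make accuracy worse
            skill, best_accuracy = candidate, accuracy
    return skill
```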
The accuracy threshold was originally ignored and unused, which made training quite difficult in practical scenarios.
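For context, a sketch of how such a threshold typically gates a training loop; `evaluate` and `improve` are assumed helper names, not the actual implementation:

```python
# Iterate until the skill clears the threshold or we run out of iterations.
def train(skill, train_df, evaluate, improve, accuracy_threshold=0.9, max_iterations=10):
    for _ in range(max_iterations):
        if evaluate(skill, train_df) >= accuracy_threshold:
            break  # accurate enough; stop spending teacher-model calls
        skill = improve(skill, train_df)
    return skill
```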
When providing feedback to the model, mention which output is wrong. Otherwise the model doesn't have enough information to tell which of its outputs is correct or incorrect.
1. Phrase the prompt in a more imperative manner, which GPT models respond to well. 2. Instruct the teacher model to avoid unnecessary rephrasing of the prompt. With GPT-4 this results in far fewer unnecessary changes. When a skill has multiple outputs, each rewrite of one skill output also changes the wording of all the other outputs, unnecessarily distorting and degrading their performance. This phrasing significantly reduces such distortion but doesn't remove it completely. Running training cycles on each skill output separately solves the problem completely but is much slower. Another potential solution (I haven't tried it yet) is to collect feedback for all outputs and apply it in a single pass; see the sketch below. More testing with real data is needed here.
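A rough sketch of that untried alternative, with hypothetical `get_feedback` and `rewrite_prompt` helpers:

```python
# Gather feedback for every output first, then request one combined rewrite,
# so rewriting one output can't silently reword the instructions for the others.
def improve_all_outputs(skill, outputs, get_feedback, rewrite_prompt):
    feedback = {out: get_feedback(skill, out) for out in outputs}
    combined = "\n".join(f"## Feedback for {out}\n{fb}" for out, fb in feedback.items())
    return rewrite_prompt(skill.instructions, combined)
```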
```diff
@@ -207,9 +207,9 @@ def get_feedback(
             [gt_pred_match.rename("match"), gt], axis=1
         )
         pred_feedback[pred_column] = match_concat.apply(
-            lambda row: "Prediction is correct."
+            lambda row: f"Prediction for {gt_column} is correct."
```
Could you explain the reason for this addition? The initial idea was to use a single column at a time, so pointing out a specific column name might not be necessary - but I'm probably missing your idea.
Classification (and Transform) skills support multiple outputs, and I used this to classify each social media post into multiple categories (each one as a True/False field).
You're correct that on each step we evaluate only one output. But the base prompt doesn't mention which of the outputs is being evaluated at that step. This patch is the easiest way I found to make sure the model understands that the feedback relates to a specific output.
If there is only one output, we can simply say "Prediction is correct." as it used to be.
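As a standalone sketch of the intended behavior (the `num_outputs` check is hypothetical; the diff itself only changes the string):

```python
def correctness_feedback(gt_column: str, num_outputs: int) -> str:
    # With one output the original wording is unambiguous; with several,
    # name the output so the model knows which prediction the feedback targets.
    if num_outputs == 1:
        return "Prediction is correct."
    return f"Prediction for {gt_column} is correct."
```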
Here is the template from TransformSkill.improve(). Note that it doesn't mention anything about the output name:
"""
## Current prompt
{self.instructions}
## Examples
{examples}
Summarize your analysis about incorrect predictions and suggest changes to the prompt."""
I got your idea about referencing specific columns to help the LLM make the correct assessment. However, the column name used there doesn't carry any signal for the LLM; here is an example from the tests:
```
Prediction for gt_0 is incorrect. Correct answer: 0 0 0 1 1 1 1 5 1
"1 1 1"
```
"gt_0" keyword is not presented in input prompt which consists of the string "Input: ... Output: ...". In this case, I'd better create a string like
"Prediction for the field "Output" is incorrect"
assuming there can be multiple outputs.
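Something along these lines, where `field_name` would come from the prompt template rather than the internal ground-truth column (a hypothetical sketch):

```python
def feedback_string(field_name: str, correct: bool) -> str:
    status = "correct" if correct else "incorrect"
    return f'Prediction for the field "{field_name}" is {status}.'

print(feedback_string("Output", correct=False))
# Prediction for the field "Output" is incorrect.
```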
Let me know if it makes sense.
Happy to merge your PR as soon as all tests pass. Thank you.
Makes sense.
I just need to figure out how to get the field name 🤔
Hi @chemeris, thanks for your great contribution!
Absolutely, it would be very helpful to have different learning strategies and reasoning paths, and to give users those options. Feel free to open a GitHub issue where we can discuss the solutions, and reference it from this PR.
Thank you for an excellent library. I ran it against our dataset of 10k social media posts to detect their sentiment and classify them into a set of topics. This set of patches is what it took for me to get it working - from minor fixes to improvements in the learning process.
I have a bunch of ideas for improving learning performance with more advanced learning strategies - I'd be happy to discuss them if there is interest.
PS Let me know if you want me to break this into smaller PRs.