
Missing support for LoRA chat, tools, and completions data formats #189

Open

ronaldmannak opened this issue Jan 30, 2025 · 7 comments

@ronaldmannak (Contributor) commented Jan 30, 2025

The loadLoRAData(url:) method currently only seems to support the text data format, e.g. {"text": "This is an example for the model."}. See here.

I was planning to add support for all of the data formats MLX supports besides text: chat, tools, and completions.

Before I proceed with implementing loading for the missing formats, I would like to confirm a couple of points:

  1. Is the lack of support for these formats simply an oversight in the existing loading code, since the example only uses text?
  2. Or is there a limitation in LoRATrain.train(model:train:validate:...) that restricts it to the text data format?

Edit: I am referring to the data formats as described in mlx-examples.
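For reference, the four formats described in mlx-examples look roughly like this, one JSON object per line (field names follow the mlx-examples LoRA documentation; the tools line is abbreviated with `...` for illustration):

```jsonl
{"text": "This is an example for the model."}
{"prompt": "What is 2 + 2?", "completion": "4"}
{"messages": [{"role": "user", "content": "What is 2 + 2?"}, {"role": "assistant", "content": "4"}]}
{"messages": [...], "tools": [{"type": "function", "function": {"name": "...", "parameters": {...}}}]}
```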

@davidkoski (Collaborator)

> I was planning to add support for all of the data formats MLX supports besides text: chat, tools, and completions.

Sounds great!

> Before I proceed with implementing loading for the missing formats, I would like to confirm a couple of points:

> 1. Is the lack of support for these formats simply an oversight in the existing loading code, since the example only uses text?

The code it was ported from only handled text, so this is just a limitation of the examples I had when porting.

> 2. Or is there a limitation in LoRATrain.train(model:train:validate:...) that restricts it to the text data format?

There might be limitations around that but I think it would be just not expecting different types of data rather than a limitation of MLX itself -- that should work fine.

One thing to consider is where the LoraTrain.swift file should live -- currently it is in MLXLLM, but it might make sense to move it to MLXLMCommon.

@ronaldmannak (Contributor, author)

Thanks @davidkoski! I'll start adding the other file formats, and I'll see if I run into any issues with LoRATrain along the way.

Re: moving LoraTrain.swift to MLXLMCommon, I can definitely do that, but do you think that will break projects that use LoRA?

@ronaldmannak (Contributor, author)

What data format is LoRATrain.train(model:train:validate:optimizer:loss:tokenizer:parameters:progress:) expecting in the train and validate parameters?

From this line in LoadJSON(url:) I understand that in the case of a text data structure, only the text values are passed, instead of a stringified JSON object. I guess that makes sense in the case of text since the JSON object only stores a single value (text).

However, the other three data structures (chat, tools, completions as described here) store multiple properties.

What kind of encoding is LoRATrain.train(model:...) expecting?

    return try String(contentsOf: url)
        .components(separatedBy: .newlines)
        .filter {
            $0.first == "{"
        }
        .compactMap {
            try JSONDecoder().decode(Line.self, from: $0.data(using: .utf8)!).text
        }

@davidkoski (Collaborator)

> Thanks @davidkoski! I'll start adding the other file formats, and I'll see if I run into any issues with LoRATrain along the way.

> Re: moving LoraTrain.swift to MLXLMCommon, I can definitely do that, but do you think that will break projects that use LoRA?

It may, but it will be very minor -- just import MLXLMCommon (which they may already be doing). We can document this.

@davidkoski (Collaborator)

> What data format is LoRATrain.train(model:train:validate:optimizer:loss:tokenizer:parameters:progress:) expecting in the train and validate parameters?
>
> From this line in LoadJSON(url:) I understand that in the case of a text data structure, only the text values are passed, instead of a stringified JSON object. I guess that makes sense in the case of text since the JSON object only stores a single value (text).
>
> However, the other three data structures (chat, tools, completions as described here) store multiple properties.
>
> What kind of encoding is LoRATrain.train(model:...) expecting?
>
>     return try String(contentsOf: url)
>         .components(separatedBy: .newlines)
>         .filter {
>             $0.first == "{"
>         }
>         .compactMap {
>             try JSONDecoder().decode(Line.self, from: $0.data(using: .utf8)!).text
>         }

The format is JSONL -- a file with one JSON object per line. I am not aware of any native Swift parser for it, so this is what we have. It may not stand up to a more complicated structure.
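For the richer formats, one possible direction (purely a sketch -- none of these type or field names exist in the codebase; the shapes follow the mlx-examples formats) is a Decodable enum that recognizes each known line shape:

```swift
import Foundation

// Hypothetical types, invented for illustration -- not part of mlx-swift-examples.
struct ChatMessage: Codable {
    let role: String
    let content: String
}

enum TrainingSample: Decodable {
    case text(String)
    case completion(prompt: String, completion: String)
    case chat(messages: [ChatMessage])

    enum CodingKeys: String, CodingKey {
        case text, prompt, completion, messages
    }

    init(from decoder: Decoder) throws {
        let c = try decoder.container(keyedBy: CodingKeys.self)
        // Probe for the distinguishing key of each format in turn.
        if let text = try c.decodeIfPresent(String.self, forKey: .text) {
            self = .text(text)
        } else if let messages = try c.decodeIfPresent([ChatMessage].self, forKey: .messages) {
            self = .chat(messages: messages)
        } else {
            self = .completion(
                prompt: try c.decode(String.self, forKey: .prompt),
                completion: try c.decode(String.self, forKey: .completion))
        }
    }
}

// Same newline-split strategy as the existing loader, but keeping the whole object.
func loadSamples(from contents: String) throws -> [TrainingSample] {
    try contents
        .components(separatedBy: .newlines)
        .filter { $0.first == "{" }
        .map { try JSONDecoder().decode(TrainingSample.self, from: Data($0.utf8)) }
}
```

The training code could then switch on the case to build the appropriate prompt, independent of how the line was stored on disk.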

Perhaps the way to think about it is what data structure should the LORA training call take? It might be something like:

protocol LORAInput {
    func lmInput() async throws -> LMInput
    func target() async throws -> String
}

and the LoRA training would take an array of these. This would encapsulate anything from simple text to something that had images or video. The training loop already has to prepare the prompt, so this would cut out a few layers there.

Anyway, this could then be independent of the file format.
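To make that concrete, a hypothetical conformance might look like the following. LMInput is stubbed with a placeholder here because the real type lives in MLXLMCommon, and TextSample and its tokenize closure are invented for illustration:

```swift
// Placeholder standing in for MLXLMCommon's LMInput; the real type
// wraps the tokenized prompt (and, for VLMs, images or video).
struct LMInput {
    let tokens: [Int]
}

protocol LORAInput {
    func lmInput() async throws -> LMInput
    func target() async throws -> String
}

// Simplest possible conformance: plain text plus a caller-supplied tokenizer.
struct TextSample: LORAInput {
    let text: String
    let tokenize: (String) -> [Int]

    func lmInput() async throws -> LMInput {
        LMInput(tokens: tokenize(text))
    }

    func target() async throws -> String {
        text
    }
}
```

A chat sample could conform the same way by rendering its messages through the model's chat template before tokenizing, which is what would make the training loop independent of the file format.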

@ronaldmannak (Contributor, author)

@davidkoski Sorry, the JSONL part I understand and actually already have working. I reused your approach of splitting the file on newlines and then decoding a single JSON object per line. From the training data I've seen so far, that works.

I've just created a draft pull request for the update I've made so far.

From your comment I understand we don't really have trainers for the different data formats, and we'll need to create separate trainers for each data format, is that correct?

@davidkoski (Collaborator)

> From your comment I understand we don't really have trainers for the different data formats, and we'll need to create separate trainers for each data format, is that correct?

We don't have something that will take inputs other than straight text and run the inference pass. That is really just a matter of which API is used and the age of the code (VLMs weren't a thing, at least in mlx-swift, when this was built). I think it is just a matter of giving it the right input (LMInput) and everything should work.

But feel free to contribute whatever you can -- we can build this up in pieces!
