
llama : add llama_batch_ext #11875

Open · wants to merge 32 commits into master
Conversation

@ngxson (Collaborator) commented on Feb 14, 2025

Ref comment: #11292 (comment)

Closes #10381

Migration patterns:

llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
// becomes:
llama_batch_ext_ptr batch = llama_batch_ext_ptr(llama_batch_ext_init(n_kv_max, 1));


common_batch_add(batch, tokens[i], pos, { 0 }, false);
// becomes:
const llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch.get(), tokens[i], pos, &seq_id, 1, false);


llama_decode(lctx, llama_batch_get_one(tokens.data(), std::min(tokens.size(), (size_t) params.n_batch)));
// becomes:
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0));
llama_decode_ext(lctx, batch.get());


llama_decode(ctx, batch);
// becomes:
llama_decode_ext(ctx, batch.get());
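Putting these patterns together, a minimal end-to-end sketch of decoding a prompt with the new API could look like the following (a sketch only, assuming a llama_context * ctx, a tokenized prompt in std::vector<llama_token> tokens, and a batch capacity n_batch; error handling is elided):

llama_batch_ext_ptr batch(llama_batch_ext_init(n_batch, 1));

const llama_seq_id seq_id = 0;
for (size_t i = 0; i < tokens.size(); ++i) {
    // request output (logits) only for the last prompt token
    const bool want_output = (i == tokens.size() - 1);
    llama_batch_ext_add_text(batch.get(), tokens[i], (llama_pos) i, &seq_id, 1, want_output);
}

if (llama_decode_ext(ctx, batch.get()) != 0) {
    // decoding failed
}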

Current status:

  • This PR currently contains the first proposal of a public API that allows hiding llama_batch from the public API --> to be discussed
  • Only llama-server works for now
  • TODO: the members of llama_batch can be migrated to C++ types

@ngxson (Collaborator, Author) commented on Feb 14, 2025

@ggerganov Would you mind having a look at this initial proposal? Thank you!

Review comment on include/llama.h (outdated), lines 266 to 272:
struct llama_batch_ext_token_info {
    llama_token    token;
    llama_pos      pos;
    int32_t        n_seq_id;
    llama_seq_id * seq_id;
    int8_t         logits;
};
Member commented:

This might not be very future-proof. Mixed-modality batches would have tokens, embeddings and tensors mixed together in the same batch. So calling llama_batch_ext_get_token_info(batch, i); is not always well-defined because it might not be a token at position i.

Maybe we can postpone this "token_info" API. I think all usages in the examples that require reading back info from the batch can be implemented in the example code without relying on the API. This way we can focus on implementing only the API for creating batches and adding data to them. Later on, when we have a better idea of the implementation, we can add a helper API to get info back from the batches.

Collaborator (Author) replied:

Yes I agree. Furthermore, this API requires doing a copy, so it won't be the best for performance. It's better to remove this API for now.

I think all usages in the examples that require reading back info from the batch can be implemented in the example code without relying on the API.

This kind of logic is currently being used inside llama-server; I'm not sure it appears in any other examples. I think I can make a thin wrapper for llama_batch_ext inside the example code. Feel free to tell me if you have a better idea.

Collaborator (Author) replied:

This API is removed in 1d6ba97; a new server_batch wrapper is added to manage token logits placement in the batch.
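The idea of the wrapper is simply to keep the per-token bookkeeping next to the opaque batch, since llama_batch_ext no longer exposes it. A hypothetical sketch (not the actual code from that commit):

#include <vector>

// Hypothetical sketch: remember what was added alongside the batch so the
// server can later map outputs back to tokens, positions and sequences.
struct server_batch {
    llama_batch_ext_ptr batch;
    std::vector<llama_pos>    pos;    // position of each added token
    std::vector<llama_seq_id> seq;    // sequence id of each added token
    std::vector<bool>         output; // whether an output was requested for it

    void add_text(llama_token token, llama_pos p, llama_seq_id s, bool out) {
        llama_batch_ext_add_text(batch.get(), token, p, &s, 1, out);
        pos.push_back(p);
        seq.push_back(s);
        output.push_back(out);
    }
};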

@ngxson (Collaborator, Author) commented on Mar 1, 2025

OK, so I've been able to apply this to various examples (not all of them). It would be nice if you could have a quick look @ggerganov before I migrate the rest.

One thing to note: the loop check over the tokens in a batch (discussed in #11875 (comment)) is used by both server.cpp and embeddings.cpp, so my solution was to create a thin wrapper called common_batch. It looks a bit messy for now, so I'm wondering if in the future we can have a llama_get_embeddings_ext or something that would make this easier.

@ggerganov (Member) commented:

The common_batch is ok for now.

It looks a bit messy for now, so I'm wondering if in the future we can have a llama_get_embeddings_ext or something that would make this easier.

It seems we rather need something to query the batch, no? How do you imagine llama_get_embeddings_ext to work?

I was thinking something like:

struct llama_batch_ext_part;

llama_batch_ext_part * part = llama_batch_ext_get_part(batch, i);
if (llama_batch_ext_part_is_token(part)) {
    llama_token id = llama_batch_ext_part_get_id(part);
    ... get token id, sequence id, etc. ...
}

But since I'm not 100% sure about all the details related to multi-modal batches yet, I think it is better to postpone this API and handle the batch information in the user code for now.

@ngxson (Collaborator, Author) commented on Mar 3, 2025

How do you imagine llama_get_embeddings_ext to work?

I don't have a clear idea yet, but I'm thinking as a developer using libllama in their program: whenever I add a token to the batch, in the case of a text token I need to know:

  • The token (token ID in case of text)
  • The pos
  • The seq_id

So when I retrieve the logits/embeddings, I would imagine that the get_embeddings function has one of these two signatures:

  • get_embeddings(seq_id) ==> we already have llama_get_embeddings_seq (see the short sketch below)
  • get_embeddings(seq_id, pos) ==> we currently need to read back the tokens from the batch
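For the first signature, this is roughly what the existing call looks like (a minimal sketch; it returns the pooled embeddings for a sequence and requires a pooling type other than NONE):

// Existing API in llama.h: pooled embeddings for one sequence.
const llama_seq_id seq_id = 0;
const float * embd = llama_get_embeddings_seq(ctx, seq_id); // NULL if pooling is NONE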

It seems we rather need something to query the batch, no?

Yes we can, and this will be quite similar to my point above. I'm thinking about these 2 options:

  • Having something like llama_batch_ext_query(seq_id, pos) that returns the output_id of the token. This can then be used with llama_get_embeddings_ith(output_id)
  • Or, explicitly have llama_batch_ext_set_output(...) that returns the output_id. That means the logits param would be removed from llama_batch_ext_add_text
  • (Edit) Or, another option: llama_batch_ext_add_text can return the output_id if logits is set to true

@ggerganov (Member) commented:

Yes we can, and this will be quite similar to my point above. I'm thinking about these 2 options:

Having something like llama_batch_ext_query(seq_id, pos) that returns the output_id of the token. This can then be used with llama_get_embeddings_ith(output_id)
Or, explicitly have llama_batch_ext_set_output(...) that returns the output_id. That means the logits param would be removed from llama_batch_ext_add_text
(Edit) Or, another option: llama_batch_ext_add_text can return the output_id if logits is set to true

Hm, yes. The llama_batch_ext_set_output() idea sounds good.

Btw, this makes me wonder if we should actually move the output buffers for logits and the embeddings to be owned by the llama_batch_ext (currently these buffers are owned by the llama_context and are shared by all batches)?

@ggerganov (Member) commented:

Now that #12181 has been merged, it should be a good time to get this merged too.

@ngxson (Collaborator, Author) commented on Mar 13, 2025

Yes, thanks for the heads-up. I'll focus on finishing this today and tomorrow.

@ngxson (Collaborator, Author) commented on Mar 13, 2025

Btw, this makes me wonder if we should actually move the output buffers for logits and the embeddings to be owned by the llama_batch_ext (currently these buffers are owned by the llama_context and are shared by all batches)?

If the output logits and embeddings are staying as float * or std::vector<float>, then yes, I think it will be better to move them to llama_batch_ext (and this can be done in a follow-up PR).

@ngxson (Collaborator, Author) commented on Mar 13, 2025

Ok so what I've done is to allow "output ID" to be returned from llama_batch_ext_add/set...

    // Add text tokens to the batch
    // Return values:
    // -1 : not enough space in the batch
    // -2 : embd is already set, cannot add text tokens
    // otherwise, returns the output ID
    LLAMA_API int32_t llama_batch_ext_add_text(
        struct llama_batch_ext * batch,
                   llama_token   token,
                     llama_pos   pos,
            const llama_seq_id * seq_ids,
                        size_t   n_seq_ids,
                          bool   output);

    // Set output (logits/embeddings) for the token in the ith sequence
    // If pos == -1, output will be set for all tokens
    // Return values:
    // -1 : the token is not in the batch
    // otherwise, returns the output ID
    LLAMA_API int32_t llama_batch_ext_set_output(
        struct llama_batch_ext * batch,
                     llama_pos   pos,
                  llama_seq_id   seq_id);

    // Set output (logits/embeddings) for the last added token
    // Return values:
    // -1 : the batch is empty
    // otherwise, returns the output ID
    LLAMA_API int32_t llama_batch_ext_set_output_last(struct llama_batch_ext * batch);

After this is merged, we can start moving the output to llama_batch_ext (we can finally call it "output" now instead of "logits"), by implementing these APIs:

float * llama_batch_ext_get_logits(int32_t output_id);
float * llama_batch_ext_get_embeddings(int32_t output_id);
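To illustrate how the output IDs would flow through these functions, a rough usage sketch (assuming a ctx, a tokenized prompt tokens, and a batch created as in the migration patterns above; llama_batch_ext_get_logits is the proposed follow-up API, not something available in this PR):

// Add the prompt tokens without requesting outputs...
const llama_seq_id seq_id = 0;
for (size_t i = 0; i < tokens.size(); ++i) {
    llama_batch_ext_add_text(batch.get(), tokens[i], (llama_pos) i, &seq_id, 1, false);
}
// ...then mark only the last added token for output and keep its output ID.
const int32_t output_id = llama_batch_ext_set_output_last(batch.get());

llama_decode_ext(ctx, batch.get());

// Proposed follow-up (not implemented here): fetch logits by output ID.
// float * logits = llama_batch_ext_get_logits(output_id);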

@ngxson (Collaborator, Author) commented on Mar 13, 2025

Off-topic: I tried Cursor to see if its indexing capability could help with the migration, but that doesn't work.

@ggerganov (Member) commented:

If the output logits and embeddings are staying as float * or std::vector<float>, then yes, I think it will be better to move them to llama_batch_ext (and this can be done in a follow-up PR).

Yes, definitely in a separate PR.

Ok so what I've done is to allow "output ID" to be returned from llama_batch_ext_add/set...

Ok, just to make sure we are on the same page: currently the "output ID" is the same as the position of the token in the batch, which works. But the output id can actually be more general than a position - it is a unique identifier associated with a requested output. And internally, the batch would keep a map of what output information should be associated with each output id. Does that make sense?

@ngxson (Collaborator, Author) commented on Mar 13, 2025

But the output id can actually be more general than a position - it is a unique identifier associated with a requested output. And internally, the batch would keep a map of what output information should be associated with each output id. Does that make sense?

Yes, that is exactly what I have in mind. For now, the output_id increases by one for each token, even if the added token does not have an output.

But in the future it may make more sense to only increase the output_id based on an internal mapping, as you said. Of course that will depend on the implementation, but my point is that the output_id will no longer necessarily be the token's position in the batch.
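For illustration only (this is not the current implementation): such an internal mapping could be as small as a table from output IDs to output-buffer rows, e.g.:

#include <cstdint>
#include <unordered_map>

// Hypothetical internal bookkeeping for llama_batch_ext: each requested output
// gets a fresh ID, and the ID maps to the row of the output buffer to read from.
struct output_id_map {
    std::unordered_map<int32_t, int32_t> id_to_row;
    int32_t next_id = 0;

    int32_t request_output(int32_t row) {
        const int32_t id = next_id++;
        id_to_row[id] = row;
        return id;
    }
};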

@ngxson (Collaborator, Author) commented on Mar 13, 2025

OK, I have been able to apply this to the rest of the examples, with some notes:

  • qwen2vl-cli sets a custom position for each token, so I will need to add support for this case in llama_batch_ext_init_from_embd (for now, I removed the deprecation of llama_decode just to make it work)
  • llama-android is not yet migrated --> no idea how to fix it, keeping it as-is
  • I'm not sure if swiftui also needs to be updated --> no idea how to fix it, keeping it as-is
  • speculative.cpp in the examples is broken for now, but I have too little knowledge to fix it. Could you please help, @ggerganov? (I temporarily added #if 0 to make it compile)
  • I should also find an AI "reviewer" that works with C++; I feel like this kind of repetitive task is perfectly suitable for AI. Unfortunately, @copilot does not work, as we already know. Update: Code Rabbit handles the large context of this PR quite well.

@ngxson (Collaborator, Author) commented on Mar 13, 2025

Small note: to make the usage of llama_batch_ext_init_from_text easier, I added an argument bool output_last that internally enables output for the last token in the input.

For the future llama_batch_ext_get_logits API, we may need to allow output_id == -1 to get the logits of the last input token (used in conjunction with bool output_last above). This can make the usage a bit simpler for the basic text-generation use case:

llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(inp.data(), inp.size() - 1, 0, 0, true));
llama_decode_ext(ctx, batch.get());
float * logits = llama_batch_ext_get_logits(-1);

@ngxson (Collaborator, Author) commented on Mar 13, 2025

@coderabbitai review (this doesn't work right now, but I created a parallel PR in my namespace and it works: ngxson#14)

Edit: let's keep it this way (separate PRs) because it is easier to follow.

@ngxson (Collaborator, Author) commented on Mar 14, 2025

@ggerganov This PR is ready for review now; everything is working except for:

  • speculative.cpp, where I used #if 0 to disable the code. I don't have enough knowledge to fix it, so could you please give it a try?
  • android and swiftui: I also don't have enough knowledge to code in Swift and Java JNI, so it's better to leave them as-is for now and maybe wait for the community to fix them later on.

@ggerganov (Member) commented:

Thanks. I'll take a look at the remaining examples.

@ngxson changed the title from "llama : private llama_batch" to "llama : add llama_batch_ext" on Mar 14, 2025

Successfully merging this pull request may close the following issue:

  • Refactor: Allow adding both tokens and embeddings to llama_batch