llama : add llama_batch_ext #11875
Conversation
@ggerganov Would you mind having a look at this initial proposal? Thank you!
include/llama.h (Outdated)
struct llama_batch_ext_token_info {
    llama_token    token;
    llama_pos      pos;
    int32_t        n_seq_id;
    llama_seq_id * seq_id;
    int8_t         logits;
};
This might not be very future-proof. Mixed-modality batches would have tokens, embeddings and tensors mixed together in the same batch, so calling llama_batch_ext_get_token_info(batch, i) is not always well-defined, because the entry at position i might not be a token.

Maybe we can postpone this "token_info" API. I think all usages in the examples that require reading info back from the batch can be implemented in the example code without relying on the API. This way we can focus on implementing only the API for creating batches and adding data to them. Later on, when we have a better idea of the implementation, we can add a helper API to get info back from the batches.
Yes, I agree. Furthermore, this API requires doing a copy, so it won't be the best for performance. It's better to remove this API for now.

> I think all usages in the examples that require to read back info from the batch can be implemented in the example code without relying on the API.

This kind of logic is currently used inside llama-server; I'm not sure it appears in any other examples. I think I can make a thin wrapper for llama_batch_ext inside the example code. Feel free to tell me if you have a better idea.
This API is removed in 1d6ba97; a new server_batch wrapper is added to manage token logits placement in the batch.
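For readers following along, here is a rough sketch of what such an example-side wrapper could look like. The struct and field names are illustrative assumptions (not the actual server_batch from 1d6ba97), and the llama_batch_ext_add_text call follows the signature proposed later in this thread:

#include <vector>
#include "llama.h"

// Hypothetical thin wrapper kept in the example code: it records what was added
// to the batch, so the example never needs to read token info back from
// llama_batch_ext itself.
struct server_batch_sketch {
    llama_batch_ext *        batch = nullptr;
    std::vector<llama_token> tokens;      // tokens in insertion order
    std::vector<bool>        has_logits;  // whether output was requested per token

    void add(llama_token tok, llama_pos pos, llama_seq_id seq, bool logits) {
        llama_batch_ext_add_text(batch, tok, pos, &seq, 1, logits);
        tokens.push_back(tok);
        has_logits.push_back(logits);
    }
};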
OK, so I've been able to apply this to various examples (not all of them). It would be nice if you could have a quick look @ggerganov before I migrate the rest. One thing to note: the loop check over tokens in the batch (discussed in #11875 (comment)) is used by both ...
> The ... It seems we rather need something to query the batch, no? How do you imagine ...

I was thinking something like:

struct llama_batch_ext_part;

llama_batch_ext_part * part = llama_batch_ext_get_part(batch, i);
if (llama_batch_ext_part_is_token(part)) {
    llama_token id = llama_batch_ext_part_get_id(part);
    // ... get token id, sequence id, etc. ...
}

But since I'm not 100% sure about all the details yet related to multi-modal batches, I think it is better to postpone this API for later, and handle the batch information in the user code for now.
I don't have a clear idea yet, but I'm thinking as a developer using ...

So when I retrieve back the logits/embeddings, I would imagine that the ...

Yes, we can, and this will be quite similar to my point above. I'm thinking about these 2 options: ...
Hm, yes. The ...

Btw, this makes me wonder if we should actually move the output buffers for the logits and the embeddings to be owned by the ...
Now that #12181 has been merged, it should be a good time to get this merged too.
Yes, thanks for the heads up. I'll focus on finishing this today & tomorrow.
If the output logits and embeddings are staying ...
Ok, so what I've done is to allow an "output ID" to be returned from these calls:

// Add text tokens to the batch
// Return values:
// -1 : not enough space in the batch
// -2 : embd is already set, cannot add text tokens
// otherwise, returns the output ID
LLAMA_API int32_t llama_batch_ext_add_text(
    struct llama_batch_ext * batch,
    llama_token token,
    llama_pos pos,
    const llama_seq_id * seq_ids,
    size_t n_seq_ids,
    bool output);

// Set output (logits/embeddings) for the token in the ith sequence
// If pos == -1, output will be set for all tokens
// Return values:
// -1 : the token is not in the batch
// otherwise, returns the output ID
LLAMA_API int32_t llama_batch_ext_set_output(
    struct llama_batch_ext * batch,
    llama_pos pos,
    llama_seq_id seq_id);

// Set output (logits/embeddings) for the last added token
// Return values:
// -1 : the batch is empty
// otherwise, returns the output ID
LLAMA_API int32_t llama_batch_ext_set_output_last(struct llama_batch_ext * batch);

After this is merged, we can start moving output to:

float * llama_batch_ext_get_logits(int32_t output_id);
float * llama_batch_ext_get_embeddings(int32_t output_id);
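To make the flow concrete, here is a hedged usage sketch of the proposed output-ID API. The function, context and token names are placeholders; the two getters are the future API quoted above (not functions that exist yet), and llama_decode_ext is assumed here to take the context like llama_decode:

#include <vector>
#include "llama.h"

// Sketch: build a batch, request output for the last token, decode, then fetch
// logits by output ID.
static void decode_prompt_sketch(llama_context * ctx,
                                 llama_batch_ext * batch,
                                 const std::vector<llama_token> & prompt) {
    for (size_t i = 0; i < prompt.size(); ++i) {
        llama_seq_id seq = 0;
        const bool want_output = (i + 1 == prompt.size()); // output only for the last token
        const int32_t out_id = llama_batch_ext_add_text(batch, prompt[i], (llama_pos) i,
                                                        &seq, 1, want_output);
        if (out_id < 0) {
            return; // not enough space in the batch, or wrong batch type
        }
    }

    // alternative: mark the last added token after the fact
    const int32_t last_out_id = llama_batch_ext_set_output_last(batch);

    llama_decode_ext(ctx, batch); // assumed signature: context + batch

    // future API: fetch the logits for the requested output by its ID
    float * logits = llama_batch_ext_get_logits(last_out_id);
    (void) logits;
}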
Off-thread: I tried Cursor to see if its indexing capability can help the migration, but that doesn't work.
Yes, definitely in a separate PR.
Ok, just to make sure we are on the same page: currently the "output ID" is the same as the position of the token in the batch, which works. But the output ID can actually be more general than a position: it is a unique identifier associated with a requested output. And internally, the batch would keep a map of what output information should be associated with each output ID. Does that make sense?
Yes, that is exactly what I have in mind. For now, the output_id increases by one for each token, even if the added token does not have output. But in the future it may make more sense to only increase output_id based on an internal mapping, as you said. Of course that will depend on the implementation, but my point is that output_id will no longer necessarily be the token position in the batch.
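As a rough illustration of the internal-mapping idea (purely hypothetical; none of these names are from the PR), the batch could map each output ID to the row that output occupies in the logits/embeddings buffers:

#include <cstdint>
#include <unordered_map>

// Hypothetical internal bookkeeping: output IDs are opaque handles, and the
// batch maps each one to the output-buffer row it corresponds to.
struct batch_output_map_sketch {
    std::unordered_map<int32_t, size_t> out_id_to_row; // output ID -> output row
    int32_t next_out_id = 0;

    // called when a token is added with output enabled
    int32_t register_output(size_t row_in_output_buffer) {
        const int32_t id = next_out_id++;
        out_id_to_row[id] = row_in_output_buffer;
        return id;
    }

    // what a future llama_batch_ext_get_logits(output_id) could use internally
    size_t row_for(int32_t output_id) const {
        return out_id_to_row.at(output_id);
    }
};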
Ok, I have been able to apply this to the rest of the examples, with some notes: ...
Small note: to make the usage of ... For the future ...

llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(inp.data(), inp.size() - 1, 0, 0, true));
llama_decode_ext(batch);
float * logits = llama_batch_ext_get_logits(-1);
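For context, llama_batch_ext_ptr above is presumably a smart-pointer helper in the style of the existing C++ wrappers (llama_model_ptr, llama_context_ptr). A minimal sketch, assuming a llama_batch_ext_free cleanup function exists (not confirmed in this thread), could be:

#include <memory>
#include "llama.h"

// Hypothetical deleter + alias; the free function name is an assumption.
struct llama_batch_ext_deleter {
    void operator()(llama_batch_ext * batch) const { llama_batch_ext_free(batch); }
};

typedef std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter> llama_batch_ext_ptr;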
Edit: let's keep it this way (separate PRs) because it is easier to follow.
This is actually a revert of ggml-org@cda0e4b.
@ggerganov This PR is ready for review now; everything is working except for: ...
Thanks. I'll take a look at the remaining examples.
Ref comment: #11292 (comment)
Closes #10381
Migration patterns: ...

Current status:
- llama_batch from public API --> To be discussed
- llama-server works for now
- llama_batch can be migrated to cpp types
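Since the original "Migration patterns" list did not survive, here is a hedged before/after sketch of the kind of change implied for example code. The llama_batch_init, common_batch_add, llama_decode and llama_batch_free calls are the existing API; the _ext init/free names are assumptions, and the other _ext calls follow the declarations discussed above (tokens, n_tokens and ctx are placeholders):

// Before: existing llama_batch usage in the examples
llama_batch batch = llama_batch_init(n_tokens, /*embd*/ 0, /*n_seq_max*/ 1);
for (int32_t i = 0; i < n_tokens; ++i) {
    common_batch_add(batch, tokens[i], i, { 0 }, i == n_tokens - 1);
}
llama_decode(ctx, batch);
llama_batch_free(batch);

// After (sketch): llama_batch_ext usage; init/free names are assumed
llama_batch_ext * batch_ext = llama_batch_ext_init(n_tokens, /*n_seq_max*/ 1);
for (int32_t i = 0; i < n_tokens; ++i) {
    llama_seq_id seq = 0;
    llama_batch_ext_add_text(batch_ext, tokens[i], i, &seq, 1, i == n_tokens - 1);
}
llama_decode_ext(ctx, batch_ext);
llama_batch_ext_free(batch_ext);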