Add tokenized example debugging during training #1520
base: main
Conversation
Related to oumi-ai#1369

Add functionality to log tokenized examples for debugging during training.

* Add a `log_tokenized_example` function in `src/oumi/builders/collators.py` to log the raw example, the formatted example, the tokenized example, and the model input.
* Modify `build_data_collator` in `src/oumi/builders/collators.py` to accept a `debug` parameter and pass it to the collators.
* Update `build_collator_from_config` in `src/oumi/builders/collators.py` to pass the `debug` parameter from the config to `build_data_collator`.
* Add a new command-line option `--debug-tokenized-example` in `src/oumi/cli/train.py` to enable logging of tokenized examples during training.
* Pass the `debug` flag to the training configuration in `src/oumi/cli/train.py`.
* Modify `TextCollatorWithPadding` in `src/oumi/core/collators/text_collator_with_padding.py` to accept a `debug` parameter and call `log_tokenized_example` in the `__call__` method if `debug` is `True`.
* Modify `TextCompletionsCollatorWithPadding` in `src/oumi/core/collators/text_completions_collator_with_padding.py` to accept a `debug` parameter and call `log_tokenized_example` in the `__call__` method if `debug` is `True`.
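For context, a minimal sketch of what such a `log_tokenized_example` helper could look like (the actual signature in this PR may differ; the stdlib `logging` setup and argument names here are illustrative):

```python
import logging
from typing import Any

logger = logging.getLogger(__name__)


def log_tokenized_example(
    raw_example: Any,
    formatted_example: str,
    tokenized_example: list[int],
    model_input: dict[str, Any],
) -> None:
    """Log one example at each stage of the data pipeline (sketch)."""
    logger.info("Raw example: %s", raw_example)
    logger.info("Formatted example: %s", formatted_example)
    logger.info("Tokenized example: %s", tokenized_example)
    logger.info("Model input: %s", model_input)
```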
Looping in Jeremy, who logged the issue, and Panos, who's the current on-call.
@vishwamartur Please auto-format your changes: https://github.com/oumi-ai/oumi/blob/main/CONTRIBUTING.md#pull-request-pr-guidelines
@@ -223,6 +227,15 @@ def __call__(self, batch) -> dict[str, Any]:
         if labels_on:
             combined_batch[_LABELS_KEY] = collated_text_inputs[_LABELS_KEY]

+        if self._debug:
More of a design concern: having a debug example logged in each batch seems excessive.
+1, let's only log this for the first sample of the first batch.
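A sketch of that one-shot guard (a stripped-down illustration, not the PR's actual code; the `_example_logged` flag is hypothetical):

```python
import logging

logger = logging.getLogger(__name__)


class TextCollatorWithPadding:
    """Stripped-down sketch: log a debug example only once."""

    def __init__(self, debug: bool = False):
        self._debug = debug
        self._example_logged = False  # one-shot guard (hypothetical)

    def __call__(self, batch: list) -> list:
        if self._debug and not self._example_logged:
            # Log only the first sample of the first batch, then disarm.
            logger.info("First tokenized example: %s", batch[0])
            self._example_logged = True
        return batch  # a real collator would pad and collate here
```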
@@ -223,6 +227,15 @@ def __call__(self, batch) -> dict[str, Any]:
         if labels_on:
             combined_batch[_LABELS_KEY] = collated_text_inputs[_LABELS_KEY]

+        if self._debug:
+            raw_example = batch[0]
+            formatted_example = tokenizer.apply_chat_template(raw_example, tokenize=False)
is it obvious at this point that the tokenizer has a chat_template under all use cases?
+1, there's no guarantee this collator has a chat template.
Also, we've already tokenized the example, so a better choice would be to simply take the input ids of the `combined_batch`'s first element and `tokenizer.decode` them into their respective strings.
It would also be useful to log all the elements of the first batch (i.e., the masks and labels as well). We don't necessarily need to decode the labels or mask, though; just the input ids.
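One way this could look (a sketch assuming a Hugging Face tokenizer and the standard `input_ids`/`attention_mask`/`labels` keys; the helper name and key names are assumptions, not oumi's actual API):

```python
from transformers import AutoTokenizer


def log_first_batch_element(tokenizer, combined_batch: dict) -> None:
    """Sketch: log the first element of an already-collated batch."""
    input_ids = combined_batch["input_ids"][0]
    # Only the input ids need decoding; masks and labels are logged raw.
    print("input_ids     :", input_ids)
    print("decoded       :", tokenizer.decode(input_ids))
    print("attention_mask:", combined_batch.get("attention_mask", [None])[0])
    print("labels        :", combined_batch.get("labels", [None])[0])


tokenizer = AutoTokenizer.from_pretrained("gpt2")
batch = tokenizer(["Hello world"], return_tensors="pt")
batch["labels"] = batch["input_ids"].clone()
log_first_batch_element(tokenizer, dict(batch))
```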
@@ -64,4 +69,13 @@ def __call__(self, batch) -> dict[str, Any]:
         # Collate batch prompts.
         collated_text_inputs = self._collate(batch)

+        if self._debug:
+            raw_example = batch[0]
+            formatted_example = self._tokenizer.apply_chat_template(raw_example, tokenize=False)
ditto.
+1
Hi @vishwamartur! Thank you very much for your contribution! I left some quick feedback. At a minimum, the code also needs fixing to pass the currently failing checks.
More broadly (+ @jgreer013, who opened the issue):
- I think logging an example within each batch formed is excessive. What do you think?
- It would be great to create a slightly more generic function wrapper that is called once and acts based on the collator/tokenizer (e.g., does it include a chat template? is it vision+text?) when the data is being prepared; see the sketch after this list.
- Minor: could the explicit `debug` variable (in each/any collator) be used for more debugging purposes in the future? If not, and if it is kept, perhaps make the name more explicit about the action it triggers (e.g., `log_example...`).
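For the second bullet, a rough sketch of such a wrapper (the dispatch conditions, key names, and function name are assumptions for illustration, not oumi APIs):

```python
import logging

logger = logging.getLogger(__name__)


def log_example_for_debugging(tokenizer, batch: list, collated: dict) -> None:
    """Sketch: one entry point that adapts its logging to the tokenizer/collator."""
    if getattr(tokenizer, "chat_template", None) is not None:
        # Conversational data: show the chat-template-formatted text.
        formatted = tokenizer.apply_chat_template(batch[0], tokenize=False)
        logger.info("Formatted example: %s", formatted)
    if "input_ids" in collated:
        # Always recoverable: decode the already-tokenized input ids.
        logger.info("Decoded input ids: %s", tokenizer.decode(collated["input_ids"][0]))
    if "pixel_values" in collated:
        # Vision+text collator: log only the image tensor shape.
        logger.info("pixel_values shape: %s", tuple(collated["pixel_values"].shape))
```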