FormatChatGenerationDPO
Format the output of a combination of a `ChatGeneration` + a preference task such as `UltraFeedback`, for Direct Preference Optimization (DPO) following the standard formatting from frameworks such as axolotl or alignment-handbook.
`FormatChatGenerationDPO` is a `Step` that formats the combined output of a `ChatGeneration`
task and a preference `Task`, i.e. a task that generates `ratings`. The `ratings` are used to rank the
existing generations and to select the `chosen` and `rejected` generations.
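As a rough illustration of the ranking described above, the sketch below picks the highest-rated generation as `chosen` and the lowest-rated one as `rejected`. It is plain Python, not distilabel's own code: the helper name `pick_chosen_and_rejected` is hypothetical, and resolving ties by first occurrence is an assumption.

```python
from typing import List, Tuple


def pick_chosen_and_rejected(
    generations: List[str], ratings: List[float]
) -> Tuple[str, str]:
    """Return (chosen, rejected) using the highest and lowest rating."""
    # Assumption: ties resolve to the first occurrence, as max()/min() do.
    chosen_idx = max(range(len(ratings)), key=lambda i: ratings[i])
    rejected_idx = min(range(len(ratings)), key=lambda i: ratings[i])
    return generations[chosen_idx], generations[rejected_idx]


chosen, rejected = pick_chosen_and_rejected(["4", "5", "6"], [1, 0, -1])
print(chosen, rejected)
```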
Note
The `messages` column should contain at least one message from the user, the `generations` column should contain at least two generations, and the `ratings` column should contain the same number of ratings as generations.
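The constraints in this note can be checked on a row before it reaches the step. The following is a minimal sketch in plain Python; `validate_row` is a hypothetical helper, and nothing here implies that the step itself performs these checks or raises these errors.

```python
from typing import Any, Dict


def validate_row(row: Dict[str, Any]) -> None:
    """Raise ValueError if a row breaks the constraints listed in the note."""
    if not any(m.get("role") == "user" for m in row["messages"]):
        raise ValueError("`messages` needs at least one user message")
    if len(row["generations"]) < 2:
        raise ValueError("`generations` needs at least two entries")
    if len(row["ratings"]) != len(row["generations"]):
        raise ValueError("`ratings` and `generations` differ in length")


# A row that satisfies all three constraints passes silently.
validate_row(
    {
        "messages": [{"role": "user", "content": "What's 2+2?"}],
        "generations": ["4", "5", "6"],
        "ratings": [1, 0, -1],
    }
)
```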
Input & Output Columns
Inputs
- messages (`List[Dict[str, str]]`): The conversation messages.
- generations (`List[str]`): The generations produced by the `LLM`.
- generation_models (`List[str]`, optional): The model names used to generate the `generations`, only available if the `model_name` from the `ChatGeneration` task/s is combined into a single column named this way, otherwise, it will be ignored.
- ratings (`List[float]`): The ratings for each of the `generations`, produced by a preference task such as `UltraFeedback`.
Outputs
- prompt (`str`): The user message used to generate the `generations` with the `LLM`.
- prompt_id (`str`): The `SHA256` hash of the `prompt`.
- chosen (`List[Dict[str, str]]`): The `chosen` generation based on the `ratings`.
- chosen_model (`str`, optional): The model name used to generate the `chosen` generation, if the `generation_models` are available.
- chosen_rating (`float`): The rating of the `chosen` generation.
- rejected (`List[Dict[str, str]]`): The `rejected` generation based on the `ratings`.
- rejected_model (`str`, optional): The model name used to generate the `rejected` generation, if the `generation_models` are available.
- rejected_rating (`float`): The rating of the `rejected` generation.
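To make the shape of the `prompt`, `prompt_id`, and `chosen`/`rejected` columns more concrete, here is a small sketch in plain Python. The helper names are hypothetical, and two details are assumptions rather than documented behavior: that the prompt is the first user turn, and that it is hashed as UTF-8 bytes.

```python
import hashlib
from typing import Dict, List

Message = Dict[str, str]


def prompt_and_id(messages: List[Message]) -> Dict[str, str]:
    # Assumption: the prompt is the first user turn, hashed as UTF-8.
    prompt = next(m["content"] for m in messages if m["role"] == "user")
    prompt_id = hashlib.sha256(prompt.encode("utf-8")).hexdigest()
    return {"prompt": prompt, "prompt_id": prompt_id}


def as_conversation(messages: List[Message], generation: str) -> List[Message]:
    # Append the generation as an assistant turn; the input list is not mutated.
    return messages + [{"role": "assistant", "content": generation}]


messages = [{"role": "user", "content": "What's 2+2?"}]
ids = prompt_and_id(messages)
chosen = as_conversation(messages, "4")
print(ids["prompt"], len(ids["prompt_id"]), chosen)
```

The `chosen` and `rejected` columns hold the full conversation, not just the generated text, which is why the helper returns the original messages plus one assistant turn.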
Examples
Format your dataset for DPO fine tuning
from distilabel.steps import FormatChatGenerationDPO

format_dpo = FormatChatGenerationDPO()
format_dpo.load()

# NOTE: "generation_models" can be added optionally.
result = next(
    format_dpo.process(
        [
            {
                "messages": [{"role": "user", "content": "What's 2+2?"}],
                "generations": ["4", "5", "6"],
                "ratings": [1, 0, -1],
            }
        ]
    )
)
# >>> result
# [
#    {
#        'messages': [{'role': 'user', 'content': "What's 2+2?"}],
#        'generations': ['4', '5', '6'],
#        'ratings': [1, 0, -1],
#        'prompt': "What's 2+2?",
#        'prompt_id': '7762ecf17ad41479767061a8f4a7bfa3b63d371672af5180872f9b82b4cd4e29',
#        'chosen': [{'role': 'user', 'content': "What's 2+2?"}, {'role': 'assistant', 'content': '4'}],
#        'chosen_rating': 1,
#        'rejected': [{'role': 'user', 'content': "What's 2+2?"}, {'role': 'assistant', 'content': '6'}],
#        'rejected_rating': -1
#    }
# ]
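When `generation_models` is present, the `chosen_model` and `rejected_model` columns would presumably be taken from the same positions as the chosen and rejected generations. The sketch below illustrates that idea in plain Python without calling distilabel; `add_model_columns` is a hypothetical helper, and the model names are placeholders.

```python
from typing import Any, Dict


def add_model_columns(row: Dict[str, Any]) -> Dict[str, Any]:
    """Copy the row and add model columns when `generation_models` exists."""
    out = dict(row)
    models = row.get("generation_models")
    if models is None:
        # Optional column missing: no model columns are added.
        return out
    ratings = row["ratings"]
    # Assumption: highest rating is chosen, lowest is rejected.
    chosen_idx = max(range(len(ratings)), key=lambda i: ratings[i])
    rejected_idx = min(range(len(ratings)), key=lambda i: ratings[i])
    out["chosen_model"] = models[chosen_idx]
    out["rejected_model"] = models[rejected_idx]
    return out


row = {
    "messages": [{"role": "user", "content": "What's 2+2?"}],
    "generations": ["4", "5", "6"],
    "generation_models": ["model-a", "model-b", "model-c"],  # placeholder names
    "ratings": [1, 0, -1],
}
enriched = add_model_columns(row)
print(enriched["chosen_model"], enriched["rejected_model"])
```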