File size: 1,981 Bytes
0f43f8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""
This module provides functionality to print a conversation with messages
colored according to the role of the speaker.
"""

import json

import typer

from termcolor import colored

# Typer CLI application; `main` is registered on it as a command below.
app = typer.Typer()


def pretty_print_conversation(messages):
    """
    Prints a conversation with messages formatted and colored by role.

    Each message is printed as ``"<role>: <content>"``, or as
    ``"function (<name>): <content>"`` for function messages. Assistant
    messages that carry a truthy ``function_call`` print that call in place
    of their content. The roles system, user, assistant and function get a
    fixed color; any other role is printed in the terminal's default color
    instead of raising.

    Parameters
    ----------
    messages : list
        A list of message dictionaries. Every message needs a 'role' key, and
        'function' messages also need a 'name' key. A missing 'content' key
        is printed as ``None``.

    Raises
    ------
    KeyError
        If a message has no 'role' key, or a 'function' message has no
        'name' key.

    """
    role_to_color = {
        "system": "red",
        "user": "green",
        "assistant": "blue",
        "function": "magenta",
    }
    for message in messages:
        role = message["role"]
        if role == "function":
            text = f"function ({message['name']}): {message.get('content')}\n"
        elif role == "assistant":
            # When a truthy function_call is present, show it instead of content.
            body = message.get("function_call") or message.get("content")
            text = f"assistant: {body}\n"
        else:
            text = f"{role}: {message.get('content')}\n"
        # Color is looked up from this message's own role. The previous
        # list.index() lookup was O(n^2). Unknown roles get None, which is
        # termcolor's "no color".
        print(colored(text, role_to_color.get(role)))


@app.command()
def main(
    messages_path: str,
):
    """
    Main function that loads messages from a JSON file and prints them using pretty formatting.

    Parameters
    ----------
    messages_path : str
        The file path to the JSON file containing the messages. Its top-level
        value is passed unchanged to ``pretty_print_conversation``, which
        iterates it as a list of message dictionaries.

    Raises
    ------
    OSError
        If the file cannot be opened (e.g. FileNotFoundError).
    json.JSONDecodeError
        If the file does not contain valid JSON.

    """
    # JSON text is UTF-8 (RFC 8259). Being explicit prevents the platform's
    # locale encoding (e.g. cp1252 on Windows) from being used instead.
    with open(messages_path, encoding="utf-8") as f:
        messages = json.load(f)

    pretty_print_conversation(messages)


if __name__ == "__main__":
    # Run the Typer CLI only when this file is executed as a script.
    app()