# NimGPT-3.5: Gradio web UI for playing Nim against a GPT-backed opponent.
# (Page-scrape residue removed here: "File size: 5,527 Bytes", blame revisions
# 1afe246 / 6198c2d / d8b3651, and a 1-130 line-number gutter.)
import gradio as gr
import datetime
from nim_game_env import NimGameEnv
from nim_gpt_functions import plan_move, execute_move

TEMPERATURE_DEFAULT = 0.5
PILES_DEFAULT = [3, 5, 7]
HUMAN_STR = "Human"
AI_STR = "AI"


def reset_game(chat_history, nim_game_env):
    """Start a brand-new game and return the values that reset the UI.

    Both arguments are accepted only so the function can be wired to the
    history and environment states; their incoming values are discarded.

    Args:
        chat_history: Ignored; a new empty history is returned instead.
        nim_game_env: Ignored; a new environment is returned instead.

    Returns:
        A 5-tuple ``(history, history, message, board_html, env)``: the same
        empty list twice (for the chatbot and the history state), an empty
        message string, the rendered starting board, and the new
        ``NimGameEnv`` built from ``PILES_DEFAULT``.
    """
    fresh_env = NimGameEnv(PILES_DEFAULT)
    # reset() yields a (text, piles) pair; only the piles are needed here.
    _, starting_piles = fresh_env.reset()
    board_html = generate_game_state_ascii_art(starting_piles, False, 0, "")
    empty_history = []
    return empty_history, empty_history, "", board_html, fresh_env


def _pile_label(index):
    """Return the display label for the pile at ``index`` (0 -> "A", 1 -> "B", ...)."""
    if index < 26:
        return chr(ord("A") + index)
    # Past "Z" there is no single letter left; fall back to the 1-based number.
    return str(index + 1)


def generate_game_state_ascii_art(piles, done, reward, player):
    """Render the board, or the game-over banner, as a ``<pre>`` HTML snippet.

    While the game is running, every pile is drawn on its own line as a row
    of ``|`` characters, labelled "Pile A", "Pile B", ... in order. Any
    number of piles is rendered (the previous version drew exactly three).

    Args:
        piles: Sequence of stick counts, one per pile. Not read when
            ``done`` is true.
        done: Whether the game has finished.
        reward: Unused; accepted so callers can pass a whole step result.
        player: Name of the winner; only shown when ``done`` is true.

    Returns:
        The rendered text wrapped in ``<pre>...</pre>``.
    """
    if done:
        body = "Game Over, " + player + " wins!"
    else:
        # Rows are separated by " \n" (space + newline) so the three-pile
        # output stays byte-identical to the original fixed layout.
        body = " \n".join(
            f"Pile {_pile_label(position)}: {'|' * sticks}"
            for position, sticks in enumerate(piles)
        )
    return "<pre>" + body + "</pre>"


def send_chat_msg(inp, chat_history, nim_game_env, temperature, openai_api_key):
    """Play the human's move, let the AI reply, and return the updated UI values.

    Args:
        inp: The human's move as free text.
        chat_history: List of ``(human_line, ai_line)`` tuples, or ``None``.
        nim_game_env: Game environment passed through to ``execute_move``.
        temperature: Sampling temperature passed to ``plan_move``.
        openai_api_key: API key passed to ``execute_move`` and ``plan_move``.

    Returns:
        A 3-tuple ``(chat_history, chat_history, html)``. ``html`` is the
        rendered board or game-over banner, or a warning when the API key or
        the move is missing (in which case the history is returned untouched
        and no move is played).
    """
    if not openai_api_key:
        warning_msg = "<pre>Please paste your OpenAI API key (see https://beta.openai.com)</pre>"
        return chat_history, chat_history, warning_msg

    # Strip BEFORE validating: otherwise whitespace-only input slips past the
    # emptiness check and is forwarded to execute_move as an empty move.
    inp = (inp or "").strip()
    if not inp:
        warning_msg = "<pre>Please enter a move</pre>"
        return chat_history, chat_history, warning_msg

    chat_history = chat_history or []

    text_obs, observation, reward, done, _ = execute_move(inp, nim_game_env, openai_api_key)

    if not done:
        # Game continues: the AI plans its reply and plays it immediately.
        output = plan_move(text_obs, temperature, openai_api_key)
        text_obs, observation, reward, done, _ = execute_move(output, nim_game_env, openai_api_key)
        banner_player = AI_STR
    elif reward == 1:
        # The human's move ended the game with reward 1: credit the human.
        output = "Good game!"
        banner_player = HUMAN_STR
    else:
        # Game ended on the human's move with any other reward: relay the
        # environment's text and name the AI in the banner.
        output = text_obs
        banner_player = AI_STR

    # Render once, from the final state of whichever branch ran.
    ascii_art = generate_game_state_ascii_art(observation, done, reward, banner_player)

    # NOTE(review): the fixed 5-hour shift presumably converts server time to
    # a local zone for the log line — confirm the intended timezone.
    print("\n==== date/time: " + str(datetime.datetime.now() - datetime.timedelta(hours=5)) + " ====")
    print("inp: " + inp, ", output: ", output, ", observation: ", observation)

    chat_history.append((HUMAN_STR + ": " + inp, AI_STR + ": " + output))
    return chat_history, chat_history, ascii_art


def update_foo(widget, state):
    """Return the value a state should take after its source widget changed.

    Args:
        widget: Current value of the widget that fired the change event.
        state: Value currently held by the state being updated.

    Returns:
        ``widget`` whenever it is not ``None``; otherwise ``state`` unchanged.
    """
    # Test against None, not truthiness: 0.0 is a valid value for the
    # temperature slider (its minimum), and the previous truthiness check
    # fell through without a return, replacing the state with None.
    return state if widget is None else widget


# Build the Gradio UI: state holders, layout, and event wiring.
block = gr.Blocks(css=".gradio-container {background-color: lightgray}")
with block as nim_game:
    # State values read and written by the event handlers wired below.
    temperature_state = gr.State(TEMPERATURE_DEFAULT)
    openai_api_key_state = gr.State()
    history_state = gr.State()
    nim_game_env_state = gr.State(NimGameEnv(PILES_DEFAULT))

    # Header row: board display, title, API key entry.
    with gr.Row():
        game_state_html = gr.Markdown()
        title = gr.Markdown("""<h3><center>NimGPT-3.5</center></h3>""")
        openai_api_key_textbox = gr.Textbox(placeholder="Paste your OpenAI API key",
                                            show_label=False, lines=1, type='password')

    chatbot = gr.Chatbot()

    # Move entry row.
    with gr.Row():
        message_tb = gr.Textbox(label="What's your move?",
                                placeholder="I'll take 2 sticks from pile A")
        send_btn = gr.Button(value="Send", variant="secondary").style(full_width=False)

    # Examples, reset button and temperature control.
    with gr.Row():
        gr.Examples(
            examples=["Three sticks from the second pile",
                      "From pile C remove 2 sticks"],
            inputs=message_tb
        )
        reset_btn = gr.Button(value="Reset Game", variant="secondary").style(full_width=False)
        temperature_slider = gr.Slider(label="GPT Temperature", value=TEMPERATURE_DEFAULT, minimum=0.0, maximum=1.0,
                                       step=0.1)

    # The Send button and pressing Enter in the textbox trigger the same
    # handler with the same inputs/outputs, so the lists are defined once.
    send_inputs = [message_tb, history_state, nim_game_env_state, temperature_state,
                   openai_api_key_state]
    send_outputs = [chatbot, history_state, game_state_html]
    send_btn.click(send_chat_msg, inputs=send_inputs, outputs=send_outputs)
    message_tb.submit(send_chat_msg, inputs=send_inputs, outputs=send_outputs)

    # Resetting happens on demand (button) and once when the page loads.
    reset_inputs = [history_state, nim_game_env_state]
    reset_outputs = [chatbot, history_state, message_tb, game_state_html, nim_game_env_state]
    reset_btn.click(reset_game, inputs=reset_inputs, outputs=reset_outputs)
    nim_game.load(reset_game, inputs=reset_inputs, outputs=reset_outputs)

    gr.Markdown("""<center>Each player may remove sticks from a pile on their turn. 
    Player to remove the last stick wins.
    <a href="https://en.wikipedia.org/wiki/Nim" target="new">
    Nim is one of the first-ever electronic computerized games</a>
    </center>""")

    # The emoji pair had been mis-decoded into mojibake; restored to the
    # intended parrot + link characters.
    gr.HTML("<center>Powered by <a href='https://github.com/hwchase17/langchain'>LangChain 🦜️🔗</a></center>")

    # Mirror widget values into their state holders whenever they change.
    openai_api_key_textbox.change(update_foo,
                                  inputs=[openai_api_key_textbox, openai_api_key_state],
                                  outputs=[openai_api_key_state])

    temperature_slider.change(update_foo,
                              inputs=[temperature_slider, temperature_state],
                              outputs=[temperature_state])

block.launch(debug=False)