Eladlev commited on
Commit
94b4a05
·
verified ·
1 Parent(s): eaed240

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +155 -247
app.py CHANGED
@@ -1,260 +1,168 @@
1
- """
2
- Entrypoint for Gradio, see https://gradio.app/
3
- """
4
-
5
- import asyncio
6
- import base64
7
- import os
8
- from datetime import datetime
9
- from enum import StrEnum
10
- from functools import partial
11
- from pathlib import Path
12
- from typing import cast, Dict
13
-
14
  import gradio as gr
15
- from anthropic import APIResponse
 
 
 
16
  from anthropic.types import TextBlock
17
  from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock
18
- from anthropic.types.tool_use_block import ToolUseBlock
19
-
20
- from computer_use_demo.loop import (
21
- PROVIDER_TO_DEFAULT_MODEL_NAME,
22
- APIProvider,
23
- sampling_loop,
24
- sampling_loop_sync,
25
- )
26
-
27
- from computer_use_demo.tools import ToolResult
28
-
29
-
30
- CONFIG_DIR = Path("~/.anthropic").expanduser()
31
- API_KEY_FILE = CONFIG_DIR / "api_key"
32
-
33
- WARNING_TEXT = "⚠️ Security Alert: Never provide access to sensitive accounts or data, as malicious web content can hijack Claude's behavior"
34
-
35
-
36
- class Sender(StrEnum):
37
- USER = "user"
38
- BOT = "assistant"
39
- TOOL = "tool"
40
-
41
-
42
- def setup_state(state):
43
- if "messages" not in state:
44
- state["messages"] = []
45
- if "api_key" not in state:
46
- # Try to load API key from file first, then environment
47
- state["api_key"] = load_from_storage("api_key") or os.getenv("ANTHROPIC_API_KEY", "")
48
- if not state["api_key"]:
49
- print("API key not found. Please set it in the environment or storage.")
50
- if "provider" not in state:
51
- state["provider"] = os.getenv("API_PROVIDER", "anthropic") or APIProvider.ANTHROPIC
52
- if "provider_radio" not in state:
53
- state["provider_radio"] = state["provider"]
54
- if "model" not in state:
55
- _reset_model(state)
56
- if "auth_validated" not in state:
57
- state["auth_validated"] = False
58
- if "responses" not in state:
59
- state["responses"] = {}
60
- if "tools" not in state:
61
- state["tools"] = {}
62
- if "only_n_most_recent_images" not in state:
63
- state["only_n_most_recent_images"] = 10
64
- if "custom_system_prompt" not in state:
65
- state["custom_system_prompt"] = load_from_storage("system_prompt") or ""
66
- # remove if want to use default system prompt
67
- state["custom_system_prompt"] += "\n\nNote that you are operating on a Windows machine, so you should use double click to open a desktop application"
68
- if "hide_images" not in state:
69
- state["hide_images"] = False
70
-
71
-
72
- def _reset_model(state):
73
- state["model"] = PROVIDER_TO_DEFAULT_MODEL_NAME[cast(APIProvider, state["provider"])]
74
-
75
-
76
- async def main(state):
77
- """Render loop for Gradio"""
78
- setup_state(state)
79
- return "Setup completed"
80
-
81
-
82
- def validate_auth(provider: APIProvider, api_key: str | None):
83
- if provider == APIProvider.ANTHROPIC:
84
- if not api_key:
85
- return "Enter your Anthropic API key to continue."
86
- if provider == APIProvider.BEDROCK:
87
- import boto3
88
-
89
- if not boto3.Session().get_credentials():
90
- return "You must have AWS credentials set up to use the Bedrock API."
91
- if provider == APIProvider.VERTEX:
92
- import google.auth
93
- from google.auth.exceptions import DefaultCredentialsError
94
-
95
- if not os.environ.get("CLOUD_ML_REGION"):
96
- return "Set the CLOUD_ML_REGION environment variable to use the Vertex API."
97
- try:
98
- google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
99
- except DefaultCredentialsError:
100
- return "Your google cloud credentials are not set up correctly."
101
-
102
-
103
- def load_from_storage(filename: str) -> str | None:
104
- """Load data from a file in the storage directory."""
105
- try:
106
- file_path = CONFIG_DIR / filename
107
- if file_path.exists():
108
- data = file_path.read_text().strip()
109
- if data:
110
- return data
111
- except Exception as e:
112
- print(f"Debug: Error loading {filename}: {e}")
113
- return None
114
-
115
-
116
- def save_to_storage(filename: str, data: str) -> None:
117
- """Save data to a file in the storage directory."""
118
- try:
119
- CONFIG_DIR.mkdir(parents=True, exist_ok=True)
120
- file_path = CONFIG_DIR / filename
121
- file_path.write_text(data)
122
- # Ensure only user can read/write the file
123
- file_path.chmod(0o600)
124
- except Exception as e:
125
- print(f"Debug: Error saving {filename}: {e}")
126
-
127
-
128
- def _api_response_callback(response: APIResponse[BetaMessage], response_state: dict):
129
- response_id = datetime.now().isoformat()
130
- response_state[response_id] = response
131
-
132
-
133
- def _tool_output_callback(tool_output: ToolResult, tool_id: str, tool_state: dict):
134
- tool_state[tool_id] = tool_output
135
-
136
-
137
- def _render_message(sender: Sender, message: str | BetaTextBlock | BetaToolUseBlock | ToolResult, state):
138
- is_tool_result = not isinstance(message, str) and (
139
- isinstance(message, ToolResult)
140
- or message.__class__.__name__ == "ToolResult"
141
- or message.__class__.__name__ == "CLIResult"
142
- )
143
- if not message or (
144
- is_tool_result
145
- and state["hide_images"]
146
- and not hasattr(message, "error")
147
- and not hasattr(message, "output")
148
- ):
149
- return
150
- if is_tool_result:
151
- message = cast(ToolResult, message)
152
- if message.output:
153
- return message.output
154
- if message.error:
155
- return f"Error: {message.error}"
156
- if message.base64_image and not state["hide_images"]:
157
- return base64.b64decode(message.base64_image)
158
- elif isinstance(message, BetaTextBlock) or isinstance(message, TextBlock):
159
- return message.text
160
- elif isinstance(message, BetaToolUseBlock) or isinstance(message, ToolUseBlock):
161
- return f"Tool Use: {message.name}\nInput: {message.input}"
162
- else:
163
- return message
164
- # open new tab, open google sheets inside, then create a new blank spreadsheet
165
-
166
- def process_input(user_input, state):
167
- # Ensure the state is properly initialized
168
- setup_state(state)
169
-
170
- # Append the user input to the messages in the state
171
- state["messages"].append(
172
- {
173
- "role": Sender.USER,
174
- "content": [TextBlock(type="text", text=user_input)],
175
- }
176
  )
 
 
 
 
 
 
177
 
178
- # Run the sampling loop synchronously and yield messages
179
- for message in sampling_loop(state):
180
- yield message
181
 
 
 
 
182
 
183
- def accumulate_messages(*args, **kwargs):
184
- """
185
- Wrapper function to accumulate messages from sampling_loop_sync.
186
- """
187
- accumulated_messages = []
188
-
189
- for message in sampling_loop_sync(*args, **kwargs):
190
- # Check if the message is already in the accumulated messages
191
- if message not in accumulated_messages:
192
- accumulated_messages.append(message)
193
- # Yield the accumulated messages as a list
194
- yield accumulated_messages
195
-
196
-
197
- def sampling_loop(state):
198
- # Ensure the API key is present
199
- if not state.get("api_key"):
200
- raise ValueError("API key is missing. Please set it in the environment or storage.")
201
-
202
- # Call the sampling loop and yield messages
203
- for message in accumulate_messages(
204
- system_prompt_suffix=state["custom_system_prompt"],
205
- model=state["model"],
206
- provider=state["provider"],
207
- messages=state["messages"],
208
- output_callback=partial(_render_message, Sender.BOT, state=state),
209
- tool_output_callback=partial(_tool_output_callback, tool_state=state["tools"]),
210
- api_response_callback=partial(_api_response_callback, response_state=state["responses"]),
211
- api_key=state["api_key"],
212
- only_n_most_recent_images=state["only_n_most_recent_images"],
213
- ):
214
- yield message
215
 
 
216
 
217
- with gr.Blocks() as demo:
218
- state = gr.State({}) # Use Gradio's state management
219
 
220
- gr.Markdown("# Claude Computer Use Demo")
221
 
222
- if not os.getenv("HIDE_WARNING", False):
223
- gr.Markdown(WARNING_TEXT)
 
 
 
 
 
 
 
 
 
 
 
 
224
 
225
  with gr.Row():
226
- provider = gr.Dropdown(
227
- label="API Provider",
228
- choices=[option.value for option in APIProvider],
229
- value="anthropic",
230
- interactive=True,
231
- )
232
- model = gr.Textbox(label="Model", value="claude-3-5-sonnet-20241022")
233
- api_key = gr.Textbox(
234
- label="Anthropic API Key",
235
- type="password",
236
- value="",
237
- interactive=True,
238
- )
239
- only_n_images = gr.Slider(
240
- label="Only send N most recent images",
241
- minimum=0,
242
- value=10,
243
- interactive=True,
244
- )
245
- custom_prompt = gr.Textbox(
246
- label="Custom System Prompt Suffix",
247
- value="",
248
- interactive=True,
249
- )
250
- hide_images = gr.Checkbox(label="Hide screenshots", value=False)
251
-
252
- api_key.change(fn=lambda key: save_to_storage(API_KEY_FILE, key), inputs=api_key)
253
- chat_input = gr.Textbox(label="Type a message to send to Claude...")
254
- # chat_output = gr.Textbox(label="Chat Output", interactive=False)
255
- chatbot = gr.Chatbot(label="Chatbot History")
256
-
257
- # Pass state as an input to the function
258
- chat_input.submit(process_input, [chat_input, state], chatbot)
259
-
260
- demo.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import io
3
+ import os
4
+ from PIL import Image, ImageDraw
5
+ from anthropic import Anthropic
6
  from anthropic.types import TextBlock
7
  from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock
8
+ max_tokens = 4096
9
+ import base64
10
+ model = 'claude-3-5-sonnet-20241022'
11
+ system = """<SYSTEM_CAPABILITY>
12
+ * You are utilizing a Windows system with internet access.
13
+ * The current date is Monday, November 18, 2024.
14
+ </SYSTEM_CAPABILITY>"""
15
+
16
+ def save_image_or_get_url(image, filename="processed_image.png"):
17
+ if not os.path.isdir("static"):
18
+ os.mkdir("static")
19
+ filepath = os.path.join("static", filename)
20
+ image.save(filepath)
21
+ return filepath
22
+
23
+ def draw_circle_on_image(image, center, radius=30):
24
+ """
25
+ Draws a circle on the given image using a center point and radius.
26
+
27
+ Parameters:
28
+ image (PIL.Image): The image to draw on.
29
+ center (tuple): A tuple (x, y) representing the center of the circle.
30
+ radius (int): The radius of the circle.
31
+
32
+ Returns:
33
+ PIL.Image: The image with the circle drawn.
34
+ """
35
+ if not isinstance(center, tuple) or len(center) != 2:
36
+ raise ValueError("Center must be a tuple of two values (x, y).")
37
+ if not isinstance(radius, (int, float)) or radius <= 0:
38
+ raise ValueError("Radius must be a positive number.")
39
+
40
+ # Calculate the bounding box for the circle
41
+ bbox = [
42
+ center[0] - radius, center[1] - radius, # Top-left corner
43
+ center[0] + radius, center[1] + radius # Bottom-right corner
44
+ ]
45
+
46
+ # Create a drawing context
47
+ draw = ImageDraw.Draw(image)
48
+
49
+ # Draw the circle
50
+ draw.ellipse(bbox, outline="red", width=15) # Change outline color and width as needed
51
+
52
+ return image
53
+
54
+
55
+ def pil_image_to_base64(pil_image):
56
+ # Save the PIL image to an in-memory buffer as a file-like object
57
+ buffered = io.BytesIO()
58
+ pil_image.save(buffered, format="PNG") # Specify format (e.g., PNG, JPEG)
59
+ buffered.seek(0) # Rewind the buffer to the beginning
60
+
61
+ # Encode the bytes from the buffer to Base64
62
+ image_data = base64.b64encode(buffered.getvalue()).decode("utf-8")
63
+ return image_data
64
+
65
+
66
+
67
+
68
+
69
+
70
+
71
+ # Function to simulate chatbot responses
72
+ def chatbot_response(input_text, image, key, chat_history):
73
+
74
+ if not key:
75
+ return chat_history + [[input_text, "Please enter a valid key."]]
76
+ if image is None:
77
+ return chat_history + [[input_text, "Please upload an image."]]
78
+ new_size = (1512, 982) # For example, resizing to 800x600 pixels
79
+ image = image.resize(new_size)
80
+ api_key =key
81
+ client = Anthropic(api_key=api_key)
82
+
83
+
84
+
85
+ messages = [{'role': 'user', 'content': [TextBlock(text=f'Look at my screenshot, {input_text}', type='text')]},
86
+ {'role': 'assistant', 'content': [BetaTextBlock(
87
+ text="I'll help you check your screen, but first I need to take a screenshot to see what you're looking at.",
88
+ type='text'), BetaToolUseBlock(id='toolu_01PSTVtavFgmx6ctaiSvacCB',
89
+ input={'action': 'screenshot'}, name='computer',
90
+ type='tool_use')]}]
91
+ image_data = pil_image_to_base64(image)
92
+
93
+ tool_res = {'role': 'user', 'content': [{'type': 'tool_result', 'tool_use_id': 'toolu_01PSTVtavFgmx6ctaiSvacCB',
94
+ 'is_error': False,
95
+ 'content': [{'type': 'image',
96
+ 'source': {'type': 'base64', 'media_type': 'image/png',
97
+ 'data': image_data}}]}]}
98
+ messages.append(tool_res)
99
+ params = [{'name': 'computer', 'type': 'computer_20241022', 'display_width_px': 1512, 'display_height_px': 982,
100
+ 'display_number': None}, {'type': 'bash_20241022', 'name': 'bash'},
101
+ {'name': 'str_replace_editor', 'type': 'text_editor_20241022'}]
102
+ raw_response = client.beta.messages.with_raw_response.create(
103
+ max_tokens=max_tokens,
104
+ messages=messages,
105
+ model=model,
106
+ system=system,
107
+ tools=params,
108
+ betas=["computer-use-2024-10-22"],
109
+ temperature=0.0,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  )
111
+ response = raw_response.parse()
112
+ scale_x = image.width / 1512
113
+ scale_y = image.height / 982
114
+ for r in response.content:
115
+ if hasattr(r, 'text'):
116
+ chat_history = chat_history + [[input_text, r.text]]
117
 
118
+ if hasattr(r, 'input') and 'coordinate' in r.input:
119
+ coordinate = r.input['coordinate']
120
+ new_image = draw_circle_on_image(image, (int(coordinate[0] * scale_x), int(coordinate[1] * scale_y)))
121
 
122
+ # Save the image or encode it as a base64 string if needed
123
+ image_url = save_image_or_get_url(
124
+ new_image) # Define this function to save or generate the URL for the image
125
 
126
+ # Include the image as part of the chat history
127
+ image_html = f'<img src="{image_url}" alt="Processed Image" style="max-width: 100%; max-height: 200px;">'
128
+ chat_history = chat_history + [[None, (image_url,)]]
129
+ return chat_history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
+ # Read the image and encode it in base64
132
 
 
 
133
 
 
134
 
135
+
136
+
137
+ # Simulated response
138
+ response = f"Received input: {input_text}\nKey: {key}\nImage uploaded successfully!"
139
+ return chat_history + [[input_text, response]]
140
+
141
+
142
+ # Create the Gradio interface
143
+ with gr.Blocks() as demo:
144
+ with gr.Row():
145
+ with gr.Column():
146
+ image_input = gr.Image(label="Upload Image", type="pil", interactive=True)
147
+ with gr.Column():
148
+ chatbot = gr.Chatbot(label="Chatbot Interaction", height=400)
149
 
150
  with gr.Row():
151
+ user_input = gr.Textbox(label="Type your message here", placeholder="Enter your message...")
152
+ key_input = gr.Textbox(label="API Key", placeholder="Enter your key...", type="password")
153
+
154
+ # Button to submit
155
+ submit_button = gr.Button("Submit")
156
+
157
+ # Initialize chat history
158
+ chat_history = gr.State(value=[])
159
+
160
+ # Set interactions
161
+ submit_button.click(
162
+ fn=chatbot_response,
163
+ inputs=[user_input, image_input, key_input, chat_history],
164
+ outputs=[chatbot],
165
+ )
166
+
167
+ # Launch the app
168
+ demo.launch()