File size: 5,825 Bytes
cb6cbda
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import json
import base64
import time
import httpx
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.serialization import load_pem_private_key

# 您的服务账号密钥(请将其保存在安全的地方,不要公开分享)
def create_jwt(client_email, private_key):
    # JWT Header
    header = json.dumps({
        "alg": "RS256",
        "typ": "JWT"
    }).encode()

    # JWT Payload
    now = int(time.time())
    payload = json.dumps({
        "iss": client_email,
        "scope": "https://www.googleapis.com/auth/cloud-platform",
        "aud": "https://oauth2.googleapis.com/token",
        "exp": now + 3600,
        "iat": now
    }).encode()

    # Encode header and payload
    segments = [
        base64.urlsafe_b64encode(header).rstrip(b'='),
        base64.urlsafe_b64encode(payload).rstrip(b'=')
    ]

    # Create signature
    signing_input = b'.'.join(segments)
    private_key = load_pem_private_key(private_key.encode(), password=None)
    signature = private_key.sign(
        signing_input,
        padding.PKCS1v15(),
        hashes.SHA256()
    )

    segments.append(base64.urlsafe_b64encode(signature).rstrip(b'='))
    return b'.'.join(segments).decode()

def get_access_token(client_email, private_key):
    jwt = create_jwt(client_email, private_key)

    with httpx.Client() as client:
        response = client.post(
            "https://oauth2.googleapis.com/token",
            data={
                "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
                "assertion": jwt
            },
            headers={'Content-Type': "application/x-www-form-urlencoded"}
        )
        response.raise_for_status()
        return response.json()["access_token"]

def ask_stream(prompt, client_email, private_key, project_id, engine):
    payload = {
    "contents": [
        {
            "role": "user",
            "parts": [
                {
                    "text": prompt
                }
            ]
        }
    ],
    "system_instruction": {
        "parts": [
            {
                "text": "You are Gemini, a large language model trained by Google. Respond conversationally"
            }
        ]
    },
    # "safety_settings": [
    #     {
    #         "category": "HARM_CATEGORY_HARASSMENT",
    #         "threshold": "BLOCK_NONE"
    #     },
    #     {
    #         "category": "HARM_CATEGORY_HATE_SPEECH",
    #         "threshold": "BLOCK_NONE"
    #     },
    #     {
    #         "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
    #         "threshold": "BLOCK_NONE"
    #     },
    #     {
    #         "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
    #         "threshold": "BLOCK_NONE"
    #     }
    # ],
    "generationConfig": {
        "temperature": 0.5,
        "max_output_tokens": 256,
        "top_k": 40,
        "top_p": 0.95
    },
    "tools": [
        {
            "function_declarations": [
                {
                    "name": "get_search_results",
                    "description": "Search Google to enhance knowledge.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "prompt": {
                                "type": "string",
                                "description": "The prompt to search."
                            }
                        },
                        "required": [
                            "prompt"
                        ]
                    }
                },
                {
                    "name": "get_url_content",
                    "description": "Get the webpage content of a URL",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "url": {
                                "type": "string",
                                "description": "the URL to request"
                            }
                        },
                        "required": [
                            "url"
                        ]
                    }
                }
            ]
        }
    ],
    "tool_config": {
        "function_calling_config": {
            "mode": "AUTO"
        }
    }
}
    # payload = {
    #     "contents": [
    #         {
    #             "role": "user",
    #             "parts": [
    #                 {
    #                     "text": prompt
    #                 }
    #             ]
    #         },
    #     ],
    #     "generationConfig": {
    #         "temperature": 0.2,
    #         "maxOutputTokens": 256,
    #         "topK": 40,
    #         "topP": 0.95
    #     }
    # }

    access_token = get_access_token(client_email, private_key)
    headers = {
        'Authorization': f"Bearer {access_token}",
        'Content-Type': "application/json"
    }

    MODEL_ID = engine
    PROJECT_ID = project_id
    stream = "generateContent"
    with httpx.Client() as client:
        response = client.post(
            f"https://us-central1-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/us-central1/publishers/google/models/{MODEL_ID}:{stream}",
            json=payload,
            headers=headers,
            timeout=600,
        )
        response.raise_for_status()
        return response.json()

# 使用示例
client_email, private_key, project_id = SERVICE_ACCOUNT_KEY["client_email"], SERVICE_ACCOUNT_KEY["private_key"], SERVICE_ACCOUNT_KEY["project_id"]
engine = "gemini-1.5-pro"
user_input = input("请输入您的问题: ")
result = ask_stream(user_input, client_email, private_key, project_id, engine)
print(json.dumps(result, ensure_ascii=False, indent=2))