|
import type { |
|
ApiData, |
|
ApiInfo, |
|
ClientOptions, |
|
Config, |
|
DuplicateOptions, |
|
EndpointInfo, |
|
JsApiData, |
|
PredictReturn, |
|
SpaceStatus, |
|
Status, |
|
UploadResponse, |
|
client_return, |
|
SubmitIterable, |
|
GradioEvent |
|
} from "./types"; |
|
import { view_api } from "./utils/view_api"; |
|
import { upload_files } from "./utils/upload_files"; |
|
import { upload, FileData } from "./upload"; |
|
import { handle_blob } from "./utils/handle_blob"; |
|
import { post_data } from "./utils/post_data"; |
|
import { predict } from "./utils/predict"; |
|
import { duplicate } from "./utils/duplicate"; |
|
import { submit } from "./utils/submit"; |
|
import { RE_SPACE_NAME, process_endpoint } from "./helpers/api_info"; |
|
import { |
|
map_names_to_ids, |
|
resolve_cookies, |
|
resolve_config, |
|
get_jwt, |
|
parse_and_set_cookies |
|
} from "./helpers/init_helpers"; |
|
import { check_and_wake_space, check_space_status } from "./helpers/spaces"; |
|
import { open_stream, readable_stream, close_stream } from "./utils/stream"; |
|
import { |
|
API_INFO_ERROR_MSG, |
|
CONFIG_ERROR_MSG, |
|
HEARTBEAT_URL, |
|
COMPONENT_SERVER_URL |
|
} from "./constants"; |
|
|
|
/**
 * Client for a Gradio app identified by a Space name or URL.
 *
 * Normally created through {@link Client.connect}, which resolves auth
 * cookies (when `options.auth` is set), the app config, an optional JWT and
 * heartbeat stream, and the API description before returning the instance.
 */
export class Client {
  app_reference: string;
  options: ClientOptions;

  config: Config | undefined;
  api_prefix = "";
  api_info: ApiInfo<JsApiData> | undefined;
  api_map: Record<string, number> = {};
  session_hash: string = Math.random().toString(36).substring(2);
  jwt: string | false = false;
  last_status: Record<string, Status["stage"]> = {};

  // Cookie header value attached by `fetch` / `stream`; set via `set_cookies`.
  private cookies: string | null = null;

  // Streaming state shared with the bound helpers (submit, open_stream, ...).
  stream_status = { open: false };
  pending_stream_messages: Record<string, any[][]> = {};
  pending_diff_streams: Record<string, any[][]> = {};
  event_callbacks: Record<string, (data?: unknown) => Promise<void>> = {};
  unclosed_events: Set<string> = new Set();
  heartbeat_event: EventSource | null = null;
  abort_controller: AbortController | null = null;
  stream_instance: EventSource | null = null;
  current_payload: any;
  // Open sockets per URL, or "failed" once a URL should use HTTP instead.
  ws_map: Record<string, WebSocket | "failed"> = {};

  /**
   * `fetch` wrapper that appends the stored cookies (if any) as a `Cookie`
   * header while keeping every other request option unchanged.
   */
  fetch(input: RequestInfo | URL, init?: RequestInit): Promise<Response> {
    const headers = new Headers(init?.headers || {});
    if (this && this.cookies) {
      headers.append("Cookie", this.cookies);
    }

    return fetch(input, { ...init, headers });
  }

  /**
   * Opens a streaming request to `url` via `readable_stream`.
   *
   * Each call installs a fresh AbortController in `this.abort_controller`, so
   * `close()` aborts the most recently opened stream.
   */
  stream(url: URL): EventSource {
    const headers = new Headers();
    if (this && this.cookies) {
      headers.append("Cookie", this.cookies);
    }

    this.abort_controller = new AbortController();

    this.stream_instance = readable_stream(url.toString(), {
      credentials: "include",
      headers: headers,
      signal: this.abort_controller.signal
    });

    return this.stream_instance;
  }

  view_api: () => Promise<ApiInfo<JsApiData>>;
  upload_files: (
    root_url: string,
    files: (Blob | File)[],
    upload_id?: string
  ) => Promise<UploadResponse>;
  upload: (
    file_data: FileData[],
    root_url: string,
    upload_id?: string,
    max_file_size?: number
  ) => Promise<(FileData | null)[] | null>;
  handle_blob: (
    endpoint: string,
    data: unknown[],
    endpoint_info: EndpointInfo<ApiData | JsApiData>
  ) => Promise<unknown[]>;
  post_data: (
    url: string,
    body: unknown,
    additional_headers?: any
  ) => Promise<unknown[]>;
  submit: (
    endpoint: string | number,
    data: unknown[] | Record<string, unknown> | undefined,
    event_data?: unknown,
    trigger_id?: number | null,
    all_events?: boolean
  ) => SubmitIterable<GradioEvent>;
  predict: (
    endpoint: string | number,
    data: unknown[] | Record<string, unknown> | undefined,
    event_data?: unknown
  ) => Promise<PredictReturn>;
  open_stream: () => Promise<void>;
  private resolve_config: (endpoint: string) => Promise<Config | undefined>;
  private resolve_cookies: () => Promise<void>;

  /**
   * @param app_reference Space name or URL of the Gradio app.
   * @param options Client options; `events` defaults to `["data"]`.
   */
  constructor(
    app_reference: string,
    options: ClientOptions = { events: ["data"] }
  ) {
    this.app_reference = app_reference;
    // Default `events` without mutating the object the caller passed in.
    this.options = options.events ? options : { ...options, events: ["data"] };
    this.current_payload = {};

    // Helpers implemented in other modules run with `this` bound to the client.
    this.view_api = view_api.bind(this);
    this.upload_files = upload_files.bind(this);
    this.handle_blob = handle_blob.bind(this);
    this.post_data = post_data.bind(this);
    this.submit = submit.bind(this);
    this.predict = predict.bind(this);
    this.open_stream = open_stream.bind(this);
    this.resolve_config = resolve_config.bind(this);
    this.resolve_cookies = resolve_cookies.bind(this);
    this.upload = upload.bind(this);

    // Methods that are passed around detached (callbacks, prepare_return_obj).
    this.fetch = this.fetch.bind(this);
    this.handle_space_success = this.handle_space_success.bind(this);
    this.stream = this.stream.bind(this);
    this.component_server = this.component_server.bind(this);
  }

  /**
   * Prepares the client: installs a WebSocket implementation when none is
   * available, resolves auth cookies, config, JWT/heartbeat, API info and
   * the endpoint-name map.
   */
  private async init(): Promise<void> {
    // `globalThis` (not Node's `global`) so this also works in Web Workers,
    // where `window` is undefined but `global` does not exist either.
    if (
      (typeof window === "undefined" || !("WebSocket" in window)) &&
      !globalThis.WebSocket
    ) {
      const ws = await import("ws");
      globalThis.WebSocket = ws.WebSocket as unknown as typeof WebSocket;
    }

    if (this.options.auth) {
      await this.resolve_cookies();
    }

    // `_resolve_config` resolves to undefined when the Space is not ready
    // yet (it then polls via `check_space_status`), so don't destructure it.
    const resolved = await this._resolve_config();
    await this._resolve_hearbeat(resolved?.config);

    this.api_info = await this.view_api();
    this.api_map = map_names_to_ids(this.config?.dependencies || []);
  }

  /**
   * Stores `_config` (when given), fetches a JWT for the Space when an
   * `hf_token` is configured, and opens the heartbeat stream if the config
   * asks for one. The (misspelled) name is kept because the method is public.
   *
   * @param _config Resolved app config, or undefined if none is available.
   */
  async _resolve_hearbeat(_config: Config | undefined): Promise<void> {
    if (_config) {
      this.config = _config;
      this.api_prefix = _config.api_prefix || "";
    }

    // A single JWT request, forwarding any auth cookies.
    if (_config?.space_id && this.options.hf_token) {
      this.jwt = await get_jwt(
        _config.space_id,
        this.options.hf_token,
        this.cookies
      );
    }

    if (this.config?.connect_heartbeat) {
      const heartbeat_url = new URL(
        `${this.config.root}${this.api_prefix}/${HEARTBEAT_URL}/${this.session_hash}`
      );

      // Sign the request when a JWT is available.
      if (this.jwt) {
        heartbeat_url.searchParams.set("__sign", this.jwt);
      }

      // At most one heartbeat stream per client; nothing here reads its messages.
      if (!this.heartbeat_event) {
        this.heartbeat_event = this.stream(heartbeat_url);
      }
    }
  }

  /** Creates a client and runs its asynchronous initialisation. */
  static async connect(
    app_reference: string,
    options: ClientOptions = {
      events: ["data"]
    }
  ): Promise<Client> {
    const client = new this(app_reference, options);
    await client.init();
    return client;
  }

  /** Closes the data stream tracked by `stream_status` / `abort_controller`. */
  close(): void {
    close_stream(this.stream_status, this.abort_controller);
  }

  set_current_payload(payload: any): void {
    this.current_payload = payload;
  }

  /** Duplicates a Space and connects to the copy (see `utils/duplicate`). */
  static async duplicate(
    app_reference: string,
    options: DuplicateOptions = {
      events: ["data"]
    }
  ): Promise<Client> {
    return duplicate(app_reference, options);
  }

  /**
   * Resolves the app config and returns `prepare_return_obj()`'s result.
   *
   * When the config cannot be loaded and both a Space id and a
   * `status_callback` are present, Space status polling is started (with
   * `handle_space_success` as callback) and the promise resolves to
   * undefined. Otherwise the callback is notified and the error rethrown.
   */
  private async _resolve_config(): Promise<any> {
    const { http_protocol, host, space_id } = await process_endpoint(
      this.app_reference,
      this.options.hf_token
    );

    const { status_callback } = this.options;

    if (space_id && status_callback) {
      await check_and_wake_space(space_id, status_callback);
    }

    try {
      const config = await this.resolve_config(`${http_protocol}//${host}`);
      if (!config) {
        throw new Error(CONFIG_ERROR_MSG);
      }

      return this.config_success(config);
    } catch (e: unknown) {
      if (space_id && status_callback) {
        void check_space_status(
          space_id,
          RE_SPACE_NAME.test(space_id) ? "space_name" : "subdomain",
          this.handle_space_success
        );
      } else {
        if (status_callback) {
          status_callback({
            status: "error",
            message: "Could not load this space.",
            load_status: "error",
            detail: "NOT_FOUND"
          });
        }
        // Rethrow the original error rather than wrapping it in a new Error.
        throw e instanceof Error ? e : new Error(String(e));
      }
    }
  }

  /**
   * Stores `_config`, upgrades its root to https when the page itself is
   * served over https, and loads the API info unless auth is required.
   * Failures in `view_api` are only logged.
   */
  private async config_success(
    _config: Config
  ): Promise<Config | client_return> {
    this.config = _config;
    this.api_prefix = _config.api_prefix || "";

    if (typeof window !== "undefined" && typeof document !== "undefined") {
      if (window.location.protocol === "https:") {
        this.config.root = this.config.root.replace("http://", "https://");
      }
    }

    if (this.config.auth_required) {
      return this.prepare_return_obj();
    }

    try {
      this.api_info = await this.view_api();
    } catch (e) {
      console.error(API_INFO_ERROR_MSG + (e as Error).message);
    }

    return this.prepare_return_obj();
  }

  /**
   * Status callback used while waiting for a Space to start. Forwards every
   * status to `options.status_callback`; once the Space reports "running",
   * re-resolves the config, sets up JWT/heartbeat and the endpoint map, and
   * returns the config.
   *
   * @throws The resolution error (after reporting it to `status_callback`).
   */
  async handle_space_success(status: SpaceStatus): Promise<Config | void> {
    if (!this) {
      throw new Error(CONFIG_ERROR_MSG);
    }
    const { status_callback } = this.options;
    if (status_callback) status_callback(status);
    if (status.status !== "running") return;

    try {
      // `_resolve_config` resolves to `prepare_return_obj()`'s object (not a
      // bare Config); `config_success` inside it already stored this.config
      // and this.api_prefix and loaded the API info.
      const resolved = await this._resolve_config();
      const config: Config | undefined = resolved?.config;

      if (!config) {
        throw new Error(CONFIG_ERROR_MSG);
      }

      // Mirror the remaining steps of `init` for a Space that woke up late.
      await this._resolve_hearbeat(config);
      this.api_map = map_names_to_ids(config.dependencies || []);

      return config;
    } catch (e) {
      if (status_callback) {
        status_callback({
          status: "error",
          message: "Could not load this space.",
          load_status: "error",
          detail: "NOT_FOUND"
        });
      }
      throw e;
    }
  }

  /**
   * Calls a component's server-side function `fn_name`.
   *
   * `data` is sent as JSON, or as multipart form fields when it has the
   * `{ binary, data }` shape. Network/HTTP/JSON failures are logged with
   * `console.warn` and the promise resolves to undefined.
   *
   * @throws If the client has no config yet.
   */
  public async component_server(
    component_id: number,
    fn_name: string,
    data: unknown[] | { binary: boolean; data: Record<string, any> }
  ): Promise<unknown> {
    if (!this.config) {
      throw new Error(CONFIG_ERROR_MSG);
    }

    const headers: {
      Authorization?: string;
      "Content-Type"?: "application/json";
    } = {};

    const { hf_token } = this.options;
    const { session_hash } = this;

    if (hf_token) {
      headers.Authorization = `Bearer ${hf_token}`;
    }

    // Prefer the component's own root_url when it declares one.
    const component = this.config.components.find(
      (comp) => comp.id === component_id
    );
    const root_url = component?.props?.root_url || this.config.root;

    let body: FormData | string;

    if ("binary" in data) {
      body = new FormData();
      for (const key in data.data) {
        if (key === "binary") continue;
        body.append(key, data.data[key]);
      }
      body.set("component_id", component_id.toString());
      body.set("fn_name", fn_name);
      body.set("session_hash", session_hash);
    } else {
      body = JSON.stringify({
        data: data,
        component_id,
        fn_name,
        session_hash
      });
      headers["Content-Type"] = "application/json";
    }

    try {
      const response = await this.fetch(
        `${root_url}${this.api_prefix}/${COMPONENT_SERVER_URL}/`,
        {
          method: "POST",
          body: body,
          headers,
          credentials: "include"
        }
      );

      if (!response.ok) {
        throw new Error(
          "Could not connect to component server: " + response.statusText
        );
      }

      return await response.json();
    } catch (e) {
      // Best-effort call: log and resolve to undefined.
      console.warn(e);
    }
  }

  /** Parses raw cookie text and stores it as the `Cookie` header value. */
  public set_cookies(raw_cookies: string): void {
    this.cookies = parse_and_set_cookies(raw_cookies).join("; ");
  }

  // Object returned by `_resolve_config`; its functions are bound to `this`.
  private prepare_return_obj(): client_return {
    return {
      config: this.config,
      predict: this.predict,
      submit: this.submit,
      view_api: this.view_api,
      component_server: this.component_server
    };
  }

  /**
   * Opens a WebSocket to `url` and registers it in `ws_map`.
   *
   * Always resolves (never rejects): on open, on error, or when the socket
   * cannot even be constructed. In the failure cases `ws_map[url]` is set to
   * "failed" so `send_ws_message` falls back to HTTP.
   */
  private async connect_ws(url: string): Promise<void> {
    return new Promise<void>((resolve) => {
      let ws: WebSocket;
      try {
        ws = new WebSocket(url);
      } catch {
        this.ws_map[url] = "failed";
        // Settle the promise so callers awaiting it do not hang.
        resolve();
        return;
      }

      ws.onopen = () => {
        resolve();
      };

      ws.onerror = (error) => {
        console.error("WebSocket error:", error);
        void this.close_ws(url);
        this.ws_map[url] = "failed";
        resolve();
      };

      ws.onclose = () => {
        this.ws_map[url] = "failed";
      };

      this.ws_map[url] = ws;
    });
  }

  /**
   * Sends `data` as JSON over the WebSocket for `url`, connecting first if
   * needed; falls back to `post_data` when no usable socket exists.
   */
  async send_ws_message(url: string, data: any): Promise<void> {
    if (!(url in this.ws_map)) {
      await this.connect_ws(url);
    }

    const ws = this.ws_map[url];
    if (ws instanceof WebSocket) {
      ws.send(JSON.stringify(data));
    } else {
      await this.post_data(url, data);
    }
  }

  /** Closes and forgets the WebSocket for `url`, if one is open. */
  async close_ws(url: string): Promise<void> {
    if (url in this.ws_map) {
      const ws = this.ws_map[url];
      if (ws instanceof WebSocket) {
        ws.close();
        delete this.ws_map[url];
      }
    }
  }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Functional entry point: connects to a Gradio app and returns the
 * initialised client (delegates to `Client.connect`).
 *
 * @param app_reference Space name or URL of the Gradio app.
 * @param options Client options; by default only `"data"` events are requested.
 * @returns The connected `Client`.
 */
export async function client(
  app_reference: string,
  options: ClientOptions = { events: ["data"] }
): Promise<Client> {
  const connected = await Client.connect(app_reference, options);
  return connected;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Functional entry point: duplicates a Space and returns a client connected
 * to the copy (delegates to `Client.duplicate`).
 *
 * @param app_reference Space to duplicate.
 * @param options Duplication and client options.
 * @returns A `Client` connected to the duplicated Space.
 */
export async function duplicate_space(
  app_reference: string,
  options: DuplicateOptions
): Promise<Client> {
  const duplicated = await Client.duplicate(app_reference, options);
  return duplicated;
}
|
|
|
/** Instance type of the `Client` class, exported as a standalone type alias. */
export type ClientInstance = InstanceType<typeof Client>;
|
|