Spaces:
Sleeping
Sleeping
Commit
·
ba986c0
1
Parent(s):
ea5dd54
refactoring for OpenAI and Groq
Browse files
src/app/queries/getSystemPrompt.ts
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { Preset } from "../engine/presets"
|
2 |
+
|
3 |
+
export function getSystemPrompt({
|
4 |
+
preset,
|
5 |
+
// prompt,
|
6 |
+
// existingPanelsTemplate,
|
7 |
+
firstNextOrLast,
|
8 |
+
maxNbPanels,
|
9 |
+
nbPanelsToGenerate,
|
10 |
+
// nbMaxNewTokens,
|
11 |
+
}: {
|
12 |
+
preset: Preset
|
13 |
+
// prompt: string
|
14 |
+
// existingPanelsTemplate: string
|
15 |
+
firstNextOrLast: string
|
16 |
+
maxNbPanels: number
|
17 |
+
nbPanelsToGenerate: number
|
18 |
+
// nbMaxNewTokens: number
|
19 |
+
}) {
|
20 |
+
return [
|
21 |
+
`You are a writer specialized in ${preset.llmPrompt}`,
|
22 |
+
`Please write detailed drawing instructions and short (2-3 sentences long) speech captions for the ${firstNextOrLast} ${nbPanelsToGenerate} panels (out of ${maxNbPanels} in total) of a new story, but keep it open-ended (it will be continued and expanded later). Please make sure each of those ${nbPanelsToGenerate} panels include info about character gender, age, origin, clothes, colors, location, lights, etc. Only generate those ${nbPanelsToGenerate} panels, but take into account the fact the panels are part of a longer story (${maxNbPanels} panels long).`,
|
23 |
+
`Give your response as a VALID JSON array like this: \`Array<{ panel: number; instructions: string; caption: string; }>\`.`,
|
24 |
+
// `Give your response as Markdown bullet points.`,
|
25 |
+
`Be brief in the instructions and narrative captions of those ${nbPanelsToGenerate} panels, don't add your own comments. The captions must be captivating, smart, entertaining. Be straight to the point, and never reply things like "Sure, I can.." etc. Reply using valid JSON!! Important: Write valid JSON!`
|
26 |
+
].filter(item => item).join("\n")
|
27 |
+
}
|
src/app/queries/getUserPrompt.ts
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export function getUserPrompt({
|
2 |
+
prompt,
|
3 |
+
existingPanelsTemplate,
|
4 |
+
}: {
|
5 |
+
prompt: string
|
6 |
+
existingPanelsTemplate: string
|
7 |
+
}) {
|
8 |
+
return `The story is about: ${prompt}.${existingPanelsTemplate}`
|
9 |
+
}
|
src/app/queries/predictNextPanels.ts
CHANGED
@@ -7,6 +7,8 @@ import { createZephyrPrompt } from "@/lib/createZephyrPrompt"
|
|
7 |
import { dirtyGeneratedPanelCleaner } from "@/lib/dirtyGeneratedPanelCleaner"
|
8 |
import { dirtyGeneratedPanelsParser } from "@/lib/dirtyGeneratedPanelsParser"
|
9 |
import { sleep } from "@/lib/sleep"
|
|
|
|
|
10 |
|
11 |
export const predictNextPanels = async ({
|
12 |
preset,
|
@@ -31,7 +33,6 @@ export const predictNextPanels = async ({
|
|
31 |
? ` To help you, here are the previous panels and their captions (note: if you see an anomaly here eg. no caption or the same description repeated multiple times, do not hesitate to fix the story): ${JSON.stringify(existingPanels, null, 2)}`
|
32 |
: ''
|
33 |
|
34 |
-
|
35 |
const firstNextOrLast =
|
36 |
existingPanels.length === 0
|
37 |
? "first"
|
@@ -39,24 +40,23 @@ export const predictNextPanels = async ({
|
|
39 |
? "last"
|
40 |
: "next"
|
41 |
|
42 |
-
const
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
}
|
57 |
]) + "\n[{"
|
58 |
|
59 |
-
|
60 |
let result = ""
|
61 |
|
62 |
// we don't require a lot of token for our task
|
@@ -66,8 +66,8 @@ export const predictNextPanels = async ({
|
|
66 |
const nbMaxNewTokens = nbPanelsToGenerate * nbTokensPerPanel
|
67 |
|
68 |
try {
|
69 |
-
// console.log(`calling predict
|
70 |
-
result = `${await predict(
|
71 |
console.log("LLM result (1st trial):", result)
|
72 |
if (!result.length) {
|
73 |
throw new Error("empty result on 1st trial!")
|
@@ -78,7 +78,7 @@ export const predictNextPanels = async ({
|
|
78 |
await sleep(2000)
|
79 |
|
80 |
try {
|
81 |
-
result = `${await predict(
|
82 |
console.log("LLM result (2nd trial):", result)
|
83 |
if (!result.length) {
|
84 |
throw new Error("empty result on 2nd trial!")
|
|
|
7 |
import { dirtyGeneratedPanelCleaner } from "@/lib/dirtyGeneratedPanelCleaner"
|
8 |
import { dirtyGeneratedPanelsParser } from "@/lib/dirtyGeneratedPanelsParser"
|
9 |
import { sleep } from "@/lib/sleep"
|
10 |
+
import { getSystemPrompt } from "./getSystemPrompt"
|
11 |
+
import { getUserPrompt } from "./getUserPrompt"
|
12 |
|
13 |
export const predictNextPanels = async ({
|
14 |
preset,
|
|
|
33 |
? ` To help you, here are the previous panels and their captions (note: if you see an anomaly here eg. no caption or the same description repeated multiple times, do not hesitate to fix the story): ${JSON.stringify(existingPanels, null, 2)}`
|
34 |
: ''
|
35 |
|
|
|
36 |
const firstNextOrLast =
|
37 |
existingPanels.length === 0
|
38 |
? "first"
|
|
|
40 |
? "last"
|
41 |
: "next"
|
42 |
|
43 |
+
const systemPrompt = getSystemPrompt({
|
44 |
+
preset,
|
45 |
+
firstNextOrLast,
|
46 |
+
maxNbPanels,
|
47 |
+
nbPanelsToGenerate,
|
48 |
+
})
|
49 |
+
|
50 |
+
const userPrompt = getUserPrompt({
|
51 |
+
prompt,
|
52 |
+
existingPanelsTemplate,
|
53 |
+
})
|
54 |
+
|
55 |
+
const zephyPrompt = createZephyrPrompt([
|
56 |
+
{ role: "system", content: systemPrompt },
|
57 |
+
{ role: "user", content: userPrompt }
|
58 |
]) + "\n[{"
|
59 |
|
|
|
60 |
let result = ""
|
61 |
|
62 |
// we don't require a lot of token for our task
|
|
|
66 |
const nbMaxNewTokens = nbPanelsToGenerate * nbTokensPerPanel
|
67 |
|
68 |
try {
|
69 |
+
// console.log(`calling predict:`, { systemPrompt, userPrompt, nbMaxNewTokens })
|
70 |
+
result = `${await predict({ systemPrompt, userPrompt, nbMaxNewTokens })}`.trim()
|
71 |
console.log("LLM result (1st trial):", result)
|
72 |
if (!result.length) {
|
73 |
throw new Error("empty result on 1st trial!")
|
|
|
78 |
await sleep(2000)
|
79 |
|
80 |
try {
|
81 |
+
result = `${await predict({ systemPrompt: systemPrompt + " \n ", userPrompt, nbMaxNewTokens })}`.trim()
|
82 |
console.log("LLM result (2nd trial):", result)
|
83 |
if (!result.length) {
|
84 |
throw new Error("empty result on 2nd trial!")
|
src/app/queries/predictWithGroq.ts
CHANGED
@@ -2,7 +2,15 @@
|
|
2 |
|
3 |
import Groq from "groq-sdk"
|
4 |
|
5 |
-
export async function predict(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
const groqApiKey = `${process.env.AUTH_GROQ_API_KEY || ""}`
|
7 |
const groqApiModel = `${process.env.LLM_GROQ_API_MODEL || "mixtral-8x7b-32768"}`
|
8 |
|
@@ -11,7 +19,8 @@ export async function predict(inputs: string, nbMaxNewTokens: number): Promise<s
|
|
11 |
})
|
12 |
|
13 |
const messages: Groq.Chat.Completions.CompletionCreateParams.Message[] = [
|
14 |
-
{ role: "
|
|
|
15 |
]
|
16 |
|
17 |
try {
|
|
|
2 |
|
3 |
import Groq from "groq-sdk"
|
4 |
|
5 |
+
export async function predict({
|
6 |
+
systemPrompt,
|
7 |
+
userPrompt,
|
8 |
+
nbMaxNewTokens,
|
9 |
+
}: {
|
10 |
+
systemPrompt: string
|
11 |
+
userPrompt: string
|
12 |
+
nbMaxNewTokens: number
|
13 |
+
}): Promise<string> {
|
14 |
const groqApiKey = `${process.env.AUTH_GROQ_API_KEY || ""}`
|
15 |
const groqApiModel = `${process.env.LLM_GROQ_API_MODEL || "mixtral-8x7b-32768"}`
|
16 |
|
|
|
19 |
})
|
20 |
|
21 |
const messages: Groq.Chat.Completions.CompletionCreateParams.Message[] = [
|
22 |
+
{ role: "system", content: systemPrompt },
|
23 |
+
{ role: "user", content: userPrompt },
|
24 |
]
|
25 |
|
26 |
try {
|
src/app/queries/predictWithHuggingFace.ts
CHANGED
@@ -2,8 +2,17 @@
|
|
2 |
|
3 |
import { HfInference, HfInferenceEndpoint } from "@huggingface/inference"
|
4 |
import { LLMEngine } from "@/types"
|
|
|
5 |
|
6 |
-
export async function predict(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
const hf = new HfInference(process.env.AUTH_HF_API_TOKEN)
|
8 |
|
9 |
const llmEngine = `${process.env.LLM_ENGINE || ""}` as LLMEngine
|
@@ -46,7 +55,12 @@ export async function predict(inputs: string, nbMaxNewTokens: number): Promise<s
|
|
46 |
try {
|
47 |
for await (const output of api.textGenerationStream({
|
48 |
model: llmEngine === "INFERENCE_ENDPOINT" ? undefined : (inferenceModel || undefined),
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
50 |
parameters: {
|
51 |
do_sample: true,
|
52 |
max_new_tokens: nbMaxNewTokens,
|
|
|
2 |
|
3 |
import { HfInference, HfInferenceEndpoint } from "@huggingface/inference"
|
4 |
import { LLMEngine } from "@/types"
|
5 |
+
import { createZephyrPrompt } from "@/lib/createZephyrPrompt"
|
6 |
|
7 |
+
export async function predict({
|
8 |
+
systemPrompt,
|
9 |
+
userPrompt,
|
10 |
+
nbMaxNewTokens,
|
11 |
+
}: {
|
12 |
+
systemPrompt: string
|
13 |
+
userPrompt: string
|
14 |
+
nbMaxNewTokens: number
|
15 |
+
}): Promise<string> {
|
16 |
const hf = new HfInference(process.env.AUTH_HF_API_TOKEN)
|
17 |
|
18 |
const llmEngine = `${process.env.LLM_ENGINE || ""}` as LLMEngine
|
|
|
55 |
try {
|
56 |
for await (const output of api.textGenerationStream({
|
57 |
model: llmEngine === "INFERENCE_ENDPOINT" ? undefined : (inferenceModel || undefined),
|
58 |
+
|
59 |
+
inputs: createZephyrPrompt([
|
60 |
+
{ role: "system", content: systemPrompt },
|
61 |
+
{ role: "user", content: userPrompt }
|
62 |
+
]) + "\n[{", // <-- important: we force its hand
|
63 |
+
|
64 |
parameters: {
|
65 |
do_sample: true,
|
66 |
max_new_tokens: nbMaxNewTokens,
|
src/app/queries/predictWithOpenAI.ts
CHANGED
@@ -1,9 +1,17 @@
|
|
1 |
"use server"
|
2 |
|
3 |
-
import type {
|
4 |
import OpenAI from "openai"
|
5 |
|
6 |
-
export async function predict(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
const openaiApiKey = `${process.env.AUTH_OPENAI_API_KEY || ""}`
|
8 |
const openaiApiBaseUrl = `${process.env.LLM_OPENAI_API_BASE_URL || "https://api.openai.com/v1"}`
|
9 |
const openaiApiModel = `${process.env.LLM_OPENAI_API_MODEL || "gpt-3.5-turbo"}`
|
@@ -13,8 +21,9 @@ export async function predict(inputs: string, nbMaxNewTokens: number): Promise<s
|
|
13 |
baseURL: openaiApiBaseUrl,
|
14 |
})
|
15 |
|
16 |
-
const messages:
|
17 |
-
{ role: "
|
|
|
18 |
]
|
19 |
|
20 |
try {
|
|
|
1 |
"use server"
|
2 |
|
3 |
+
import type { ChatCompletionMessageParam } from "openai/resources/chat"
|
4 |
import OpenAI from "openai"
|
5 |
|
6 |
+
export async function predict({
|
7 |
+
systemPrompt,
|
8 |
+
userPrompt,
|
9 |
+
nbMaxNewTokens,
|
10 |
+
}: {
|
11 |
+
systemPrompt: string
|
12 |
+
userPrompt: string
|
13 |
+
nbMaxNewTokens: number
|
14 |
+
}): Promise<string> {
|
15 |
const openaiApiKey = `${process.env.AUTH_OPENAI_API_KEY || ""}`
|
16 |
const openaiApiBaseUrl = `${process.env.LLM_OPENAI_API_BASE_URL || "https://api.openai.com/v1"}`
|
17 |
const openaiApiModel = `${process.env.LLM_OPENAI_API_MODEL || "gpt-3.5-turbo"}`
|
|
|
21 |
baseURL: openaiApiBaseUrl,
|
22 |
})
|
23 |
|
24 |
+
const messages: ChatCompletionMessageParam[] = [
|
25 |
+
{ role: "system", content: systemPrompt },
|
26 |
+
{ role: "user", content: userPrompt },
|
27 |
]
|
28 |
|
29 |
try {
|