jbilcke-hf HF staff commited on
Commit
ba986c0
·
1 Parent(s): ea5dd54

refactoring for OpenAI and Groq

Browse files
src/app/queries/getSystemPrompt.ts ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { Preset } from "../engine/presets"
2
+
3
+ export function getSystemPrompt({
4
+ preset,
5
+ // prompt,
6
+ // existingPanelsTemplate,
7
+ firstNextOrLast,
8
+ maxNbPanels,
9
+ nbPanelsToGenerate,
10
+ // nbMaxNewTokens,
11
+ }: {
12
+ preset: Preset
13
+ // prompt: string
14
+ // existingPanelsTemplate: string
15
+ firstNextOrLast: string
16
+ maxNbPanels: number
17
+ nbPanelsToGenerate: number
18
+ // nbMaxNewTokens: number
19
+ }) {
20
+ return [
21
+ `You are a writer specialized in ${preset.llmPrompt}`,
22
+ `Please write detailed drawing instructions and short (2-3 sentences long) speech captions for the ${firstNextOrLast} ${nbPanelsToGenerate} panels (out of ${maxNbPanels} in total) of a new story, but keep it open-ended (it will be continued and expanded later). Please make sure each of those ${nbPanelsToGenerate} panels include info about character gender, age, origin, clothes, colors, location, lights, etc. Only generate those ${nbPanelsToGenerate} panels, but take into account the fact the panels are part of a longer story (${maxNbPanels} panels long).`,
23
+ `Give your response as a VALID JSON array like this: \`Array<{ panel: number; instructions: string; caption: string; }>\`.`,
24
+ // `Give your response as Markdown bullet points.`,
25
+ `Be brief in the instructions and narrative captions of those ${nbPanelsToGenerate} panels, don't add your own comments. The captions must be captivating, smart, entertaining. Be straight to the point, and never reply things like "Sure, I can.." etc. Reply using valid JSON!! Important: Write valid JSON!`
26
+ ].filter(item => item).join("\n")
27
+ }
src/app/queries/getUserPrompt.ts ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ export function getUserPrompt({
2
+ prompt,
3
+ existingPanelsTemplate,
4
+ }: {
5
+ prompt: string
6
+ existingPanelsTemplate: string
7
+ }) {
8
+ return `The story is about: ${prompt}.${existingPanelsTemplate}`
9
+ }
src/app/queries/predictNextPanels.ts CHANGED
@@ -7,6 +7,8 @@ import { createZephyrPrompt } from "@/lib/createZephyrPrompt"
7
  import { dirtyGeneratedPanelCleaner } from "@/lib/dirtyGeneratedPanelCleaner"
8
  import { dirtyGeneratedPanelsParser } from "@/lib/dirtyGeneratedPanelsParser"
9
  import { sleep } from "@/lib/sleep"
 
 
10
 
11
  export const predictNextPanels = async ({
12
  preset,
@@ -31,7 +33,6 @@ export const predictNextPanels = async ({
31
  ? ` To help you, here are the previous panels and their captions (note: if you see an anomaly here eg. no caption or the same description repeated multiple times, do not hesitate to fix the story): ${JSON.stringify(existingPanels, null, 2)}`
32
  : ''
33
 
34
-
35
  const firstNextOrLast =
36
  existingPanels.length === 0
37
  ? "first"
@@ -39,24 +40,23 @@ export const predictNextPanels = async ({
39
  ? "last"
40
  : "next"
41
 
42
- const query = createZephyrPrompt([
43
- {
44
- role: "system",
45
- content: [
46
- `You are a writer specialized in ${preset.llmPrompt}`,
47
- `Please write detailed drawing instructions and short (2-3 sentences long) speech captions for the ${firstNextOrLast} ${nbPanelsToGenerate} panels (out of ${maxNbPanels} in total) of a new story, but keep it open-ended (it will be continued and expanded later). Please make sure each of those ${nbPanelsToGenerate} panels include info about character gender, age, origin, clothes, colors, location, lights, etc. Only generate those ${nbPanelsToGenerate} panels, but take into account the fact the panels are part of a longer story (${maxNbPanels} panels long).`,
48
- `Give your response as a VALID JSON array like this: \`Array<{ panel: number; instructions: string; caption: string; }>\`.`,
49
- // `Give your response as Markdown bullet points.`,
50
- `Be brief in the instructions and narrative captions of those ${nbPanelsToGenerate} panels, don't add your own comments. The captions must be captivating, smart, entertaining. Be straight to the point, and never reply things like "Sure, I can.." etc. Reply using valid JSON!! Important: Write valid JSON!`
51
- ].filter(item => item).join("\n")
52
- },
53
- {
54
- role: "user",
55
- content: `The story is about: ${prompt}.${existingPanelsTemplate}`,
56
- }
57
  ]) + "\n[{"
58
 
59
-
60
  let result = ""
61
 
62
  // we don't require a lot of token for our task
@@ -66,8 +66,8 @@ export const predictNextPanels = async ({
66
  const nbMaxNewTokens = nbPanelsToGenerate * nbTokensPerPanel
67
 
68
  try {
69
- // console.log(`calling predict(${query}, ${nbTotalPanels})`)
70
- result = `${await predict(query, nbMaxNewTokens)}`.trim()
71
  console.log("LLM result (1st trial):", result)
72
  if (!result.length) {
73
  throw new Error("empty result on 1st trial!")
@@ -78,7 +78,7 @@ export const predictNextPanels = async ({
78
  await sleep(2000)
79
 
80
  try {
81
- result = `${await predict(query + " \n ", nbMaxNewTokens)}`.trim()
82
  console.log("LLM result (2nd trial):", result)
83
  if (!result.length) {
84
  throw new Error("empty result on 2nd trial!")
 
7
  import { dirtyGeneratedPanelCleaner } from "@/lib/dirtyGeneratedPanelCleaner"
8
  import { dirtyGeneratedPanelsParser } from "@/lib/dirtyGeneratedPanelsParser"
9
  import { sleep } from "@/lib/sleep"
10
+ import { getSystemPrompt } from "./getSystemPrompt"
11
+ import { getUserPrompt } from "./getUserPrompt"
12
 
13
  export const predictNextPanels = async ({
14
  preset,
 
33
  ? ` To help you, here are the previous panels and their captions (note: if you see an anomaly here eg. no caption or the same description repeated multiple times, do not hesitate to fix the story): ${JSON.stringify(existingPanels, null, 2)}`
34
  : ''
35
 
 
36
  const firstNextOrLast =
37
  existingPanels.length === 0
38
  ? "first"
 
40
  ? "last"
41
  : "next"
42
 
43
+ const systemPrompt = getSystemPrompt({
44
+ preset,
45
+ firstNextOrLast,
46
+ maxNbPanels,
47
+ nbPanelsToGenerate,
48
+ })
49
+
50
+ const userPrompt = getUserPrompt({
51
+ prompt,
52
+ existingPanelsTemplate,
53
+ })
54
+
55
+ const zephyPrompt = createZephyrPrompt([
56
+ { role: "system", content: systemPrompt },
57
+ { role: "user", content: userPrompt }
58
  ]) + "\n[{"
59
 
 
60
  let result = ""
61
 
62
  // we don't require a lot of token for our task
 
66
  const nbMaxNewTokens = nbPanelsToGenerate * nbTokensPerPanel
67
 
68
  try {
69
+ // console.log(`calling predict:`, { systemPrompt, userPrompt, nbMaxNewTokens })
70
+ result = `${await predict({ systemPrompt, userPrompt, nbMaxNewTokens })}`.trim()
71
  console.log("LLM result (1st trial):", result)
72
  if (!result.length) {
73
  throw new Error("empty result on 1st trial!")
 
78
  await sleep(2000)
79
 
80
  try {
81
+ result = `${await predict({ systemPrompt: systemPrompt + " \n ", userPrompt, nbMaxNewTokens })}`.trim()
82
  console.log("LLM result (2nd trial):", result)
83
  if (!result.length) {
84
  throw new Error("empty result on 2nd trial!")
src/app/queries/predictWithGroq.ts CHANGED
@@ -2,7 +2,15 @@
2
 
3
  import Groq from "groq-sdk"
4
 
5
- export async function predict(inputs: string, nbMaxNewTokens: number): Promise<string> {
 
 
 
 
 
 
 
 
6
  const groqApiKey = `${process.env.AUTH_GROQ_API_KEY || ""}`
7
  const groqApiModel = `${process.env.LLM_GROQ_API_MODEL || "mixtral-8x7b-32768"}`
8
 
@@ -11,7 +19,8 @@ export async function predict(inputs: string, nbMaxNewTokens: number): Promise<s
11
  })
12
 
13
  const messages: Groq.Chat.Completions.CompletionCreateParams.Message[] = [
14
- { role: "assistant", content: inputs },
 
15
  ]
16
 
17
  try {
 
2
 
3
  import Groq from "groq-sdk"
4
 
5
+ export async function predict({
6
+ systemPrompt,
7
+ userPrompt,
8
+ nbMaxNewTokens,
9
+ }: {
10
+ systemPrompt: string
11
+ userPrompt: string
12
+ nbMaxNewTokens: number
13
+ }): Promise<string> {
14
  const groqApiKey = `${process.env.AUTH_GROQ_API_KEY || ""}`
15
  const groqApiModel = `${process.env.LLM_GROQ_API_MODEL || "mixtral-8x7b-32768"}`
16
 
 
19
  })
20
 
21
  const messages: Groq.Chat.Completions.CompletionCreateParams.Message[] = [
22
+ { role: "system", content: systemPrompt },
23
+ { role: "user", content: userPrompt },
24
  ]
25
 
26
  try {
src/app/queries/predictWithHuggingFace.ts CHANGED
@@ -2,8 +2,17 @@
2
 
3
  import { HfInference, HfInferenceEndpoint } from "@huggingface/inference"
4
  import { LLMEngine } from "@/types"
 
5
 
6
- export async function predict(inputs: string, nbMaxNewTokens: number): Promise<string> {
 
 
 
 
 
 
 
 
7
  const hf = new HfInference(process.env.AUTH_HF_API_TOKEN)
8
 
9
  const llmEngine = `${process.env.LLM_ENGINE || ""}` as LLMEngine
@@ -46,7 +55,12 @@ export async function predict(inputs: string, nbMaxNewTokens: number): Promise<s
46
  try {
47
  for await (const output of api.textGenerationStream({
48
  model: llmEngine === "INFERENCE_ENDPOINT" ? undefined : (inferenceModel || undefined),
49
- inputs,
 
 
 
 
 
50
  parameters: {
51
  do_sample: true,
52
  max_new_tokens: nbMaxNewTokens,
 
2
 
3
  import { HfInference, HfInferenceEndpoint } from "@huggingface/inference"
4
  import { LLMEngine } from "@/types"
5
+ import { createZephyrPrompt } from "@/lib/createZephyrPrompt"
6
 
7
+ export async function predict({
8
+ systemPrompt,
9
+ userPrompt,
10
+ nbMaxNewTokens,
11
+ }: {
12
+ systemPrompt: string
13
+ userPrompt: string
14
+ nbMaxNewTokens: number
15
+ }): Promise<string> {
16
  const hf = new HfInference(process.env.AUTH_HF_API_TOKEN)
17
 
18
  const llmEngine = `${process.env.LLM_ENGINE || ""}` as LLMEngine
 
55
  try {
56
  for await (const output of api.textGenerationStream({
57
  model: llmEngine === "INFERENCE_ENDPOINT" ? undefined : (inferenceModel || undefined),
58
+
59
+ inputs: createZephyrPrompt([
60
+ { role: "system", content: systemPrompt },
61
+ { role: "user", content: userPrompt }
62
+ ]) + "\n[{", // <-- important: we force its hand
63
+
64
  parameters: {
65
  do_sample: true,
66
  max_new_tokens: nbMaxNewTokens,
src/app/queries/predictWithOpenAI.ts CHANGED
@@ -1,9 +1,17 @@
1
  "use server"
2
 
3
- import type { ChatCompletionMessage } from "openai/resources/chat"
4
  import OpenAI from "openai"
5
 
6
- export async function predict(inputs: string, nbMaxNewTokens: number): Promise<string> {
 
 
 
 
 
 
 
 
7
  const openaiApiKey = `${process.env.AUTH_OPENAI_API_KEY || ""}`
8
  const openaiApiBaseUrl = `${process.env.LLM_OPENAI_API_BASE_URL || "https://api.openai.com/v1"}`
9
  const openaiApiModel = `${process.env.LLM_OPENAI_API_MODEL || "gpt-3.5-turbo"}`
@@ -13,8 +21,9 @@ export async function predict(inputs: string, nbMaxNewTokens: number): Promise<s
13
  baseURL: openaiApiBaseUrl,
14
  })
15
 
16
- const messages: ChatCompletionMessage[] = [
17
- { role: "assistant", content: inputs },
 
18
  ]
19
 
20
  try {
 
1
  "use server"
2
 
3
+ import type { ChatCompletionMessageParam } from "openai/resources/chat"
4
  import OpenAI from "openai"
5
 
6
+ export async function predict({
7
+ systemPrompt,
8
+ userPrompt,
9
+ nbMaxNewTokens,
10
+ }: {
11
+ systemPrompt: string
12
+ userPrompt: string
13
+ nbMaxNewTokens: number
14
+ }): Promise<string> {
15
  const openaiApiKey = `${process.env.AUTH_OPENAI_API_KEY || ""}`
16
  const openaiApiBaseUrl = `${process.env.LLM_OPENAI_API_BASE_URL || "https://api.openai.com/v1"}`
17
  const openaiApiModel = `${process.env.LLM_OPENAI_API_MODEL || "gpt-3.5-turbo"}`
 
21
  baseURL: openaiApiBaseUrl,
22
  })
23
 
24
+ const messages: ChatCompletionMessageParam[] = [
25
+ { role: "system", content: systemPrompt },
26
+ { role: "user", content: userPrompt },
27
  ]
28
 
29
  try {