jbilcke-hf HF staff commited on
Commit
1cf03f7
·
1 Parent(s): 6eef442

switching to a cheaper interpolation engine

Browse files
.env CHANGED
@@ -1,13 +1,18 @@
1
  AUTH_REPLICATE_API_TOKEN="<USE YOUR OWN>"
2
  AUTH_HOTSHOT_XL_API_GRADIO_ACCESS_TOKEN="<USE YOUR OWN>"
 
3
 
4
- # TODO: those could be
5
  # VIDEO_HOTSHOT_XL_API_OFFICIAL
6
- # VIDEO_HOTSHOT_XL_API_NODE
7
  # VIDEO_HOTSHOT_XL_API_GRADIO
8
  # VIDEO_HOTSHOT_XL_API_REPLICATE
9
  VIDEO_ENGINE="VIDEO_HOTSHOT_XL_API_GRADIO"
10
 
 
 
 
 
 
11
  # the official API developed by the Hotshot-XL team
12
  # note: it isn't released yet
13
  VIDEO_HOTSHOT_XL_API_OFFICIAL=""
@@ -25,9 +30,13 @@ VIDEO_HOTSHOT_XL_API_GRADIO="https://jbilcke-hf-hotshot-xl-server-1.hf.space"
25
  VIDEO_HOTSHOT_XL_API_REPLICATE_MODEL="cloneofsimo/hotshot-xl-lora-controlnet"
26
  VIDEO_HOTSHOT_XL_API_REPLICATE_MODEL_VERSION="75e26ffd033a59a78954a3d675632f47f7f8470402aec51c255b9f9b7b62568b"
27
 
 
 
28
  INTERPOLATION_API_REPLICATE_MODEL="zsxkib/st-mfnet"
29
  INTERPOLATION_API_REPLICATE_MODEL_VERSION="faa7693430b0a4ac95d1b8e25165673c1d7a7263537a7c4bb9be82a3e2d130fb"
30
 
 
 
31
  # ----------- RATE LIMIT -------
32
  ENABLE_RATE_LIMIT=""
33
  UPSTASH_REDIS_REST_URL="<USE YOUR OWN>"
 
1
  AUTH_REPLICATE_API_TOKEN="<USE YOUR OWN>"
2
  AUTH_HOTSHOT_XL_API_GRADIO_ACCESS_TOKEN="<USE YOUR OWN>"
3
+ AUTH_INTERPOLATION_API_GRADIO_TOKEN="<USE YOUR OWN>"
4
 
5
+ # TODO: support multiple backends
6
  # VIDEO_HOTSHOT_XL_API_OFFICIAL
 
7
  # VIDEO_HOTSHOT_XL_API_GRADIO
8
  # VIDEO_HOTSHOT_XL_API_REPLICATE
9
  VIDEO_ENGINE="VIDEO_HOTSHOT_XL_API_GRADIO"
10
 
11
+ # TODO: support multiple backends
12
+ # STMFNET_REPLICATE
13
+ # FILM_GRADIO
14
+ INTERPOLATION_ENGINE="STMFNET_REPLICATE"
15
+
16
  # the official API developed by the Hotshot-XL team
17
  # note: it isn't released yet
18
  VIDEO_HOTSHOT_XL_API_OFFICIAL=""
 
30
  VIDEO_HOTSHOT_XL_API_REPLICATE_MODEL="cloneofsimo/hotshot-xl-lora-controlnet"
31
  VIDEO_HOTSHOT_XL_API_REPLICATE_MODEL_VERSION="75e26ffd033a59a78954a3d675632f47f7f8470402aec51c255b9f9b7b62568b"
32
 
33
+ # ----------- INTERPOLATION ------
34
+
35
  INTERPOLATION_API_REPLICATE_MODEL="zsxkib/st-mfnet"
36
  INTERPOLATION_API_REPLICATE_MODEL_VERSION="faa7693430b0a4ac95d1b8e25165673c1d7a7263537a7c4bb9be82a3e2d130fb"
37
 
38
+ INTERPOLATION_API_GRADIO_URL="https://jbilcke-hf-video-interpolation-server.hf.space"
39
+
40
  # ----------- RATE LIMIT -------
41
  ENABLE_RATE_LIMIT=""
42
  UPSTASH_REDIS_REST_URL="<USE YOUR OWN>"
package-lock.json CHANGED
The diff for this file is too large to render. See raw diff
 
public/images/models/sdxl-akira.jpg ADDED
public/images/models/sdxl-starfield.jpg ADDED
public/images/sign-in-with-huggingface-xl.svg ADDED
src/app/interface/auth-dialog/index.tsx ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ export function AuthDialog() {
2
+ return (
3
+ <>
4
+ </>
5
+ )
6
+ }
src/app/interface/generate/index.tsx CHANGED
@@ -9,12 +9,13 @@ import { cn } from "@/lib/utils"
9
  import { headingFont } from "@/app/interface/fonts"
10
  import { useCharacterLimit } from "@/lib/useCharacterLimit"
11
  import { generateAnimation } from "@/app/server/actions/animation"
 
12
  import { getLatestPosts, getPost, postToCommunity } from "@/app/server/actions/community"
13
  import { getSDXLModels } from "@/app/server/actions/models"
14
  import { HotshotImageInferenceSize, Post, QualityLevel, QualityOption, SDXLModel } from "@/types"
15
  import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"
16
  import { TooltipProvider } from "@radix-ui/react-tooltip"
17
- import { interpolate } from "@/app/server/actions/interpolate"
18
  import { isRateLimitError } from "@/app/server/utils/isRateLimitError"
19
  import { useCountdown } from "@/lib/useCountdown"
20
 
@@ -212,7 +213,7 @@ export function Generate() {
212
  setRuns(runsRef.current + 1)
213
 
214
  try {
215
- assetUrl = await interpolate(rawAssetUrl)
216
 
217
  if (!assetUrl) {
218
  throw new Error("invalid interpolated asset url")
 
9
  import { headingFont } from "@/app/interface/fonts"
10
  import { useCharacterLimit } from "@/lib/useCharacterLimit"
11
  import { generateAnimation } from "@/app/server/actions/animation"
12
+ import { interpolateVideo } from "@/app/server/actions/interpolation"
13
  import { getLatestPosts, getPost, postToCommunity } from "@/app/server/actions/community"
14
  import { getSDXLModels } from "@/app/server/actions/models"
15
  import { HotshotImageInferenceSize, Post, QualityLevel, QualityOption, SDXLModel } from "@/types"
16
  import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"
17
  import { TooltipProvider } from "@radix-ui/react-tooltip"
18
+
19
  import { isRateLimitError } from "@/app/server/utils/isRateLimitError"
20
  import { useCountdown } from "@/lib/useCountdown"
21
 
 
213
  setRuns(runsRef.current + 1)
214
 
215
  try {
216
+ assetUrl = await interpolateVideo(rawAssetUrl)
217
 
218
  if (!assetUrl) {
219
  throw new Error("invalid interpolated asset url")
src/app/server/actions/animation.ts CHANGED
@@ -5,8 +5,8 @@ import {Redis} from "@upstash/redis"
5
 
6
  import { VideoOptions } from "@/types"
7
 
8
- import { generateVideoWithGradioAPI } from "./generateWithGradioApi"
9
- import { generateVideoWithReplicateAPI } from "./generateWithReplicateAPI"
10
  import { filterOutBadWords } from "./censorship"
11
 
12
  const videoEngine = `${process.env.VIDEO_ENGINE || ""}`
@@ -102,7 +102,7 @@ export async function generateAnimation({
102
  try {
103
 
104
  if (videoEngine === "VIDEO_HOTSHOT_XL_API_REPLICATE") {
105
- return generateVideoWithReplicateAPI({
106
  positivePrompt,
107
  negativePrompt,
108
  size,
@@ -144,7 +144,7 @@ export async function generateAnimation({
144
 
145
  return content
146
  } else if (videoEngine === "VIDEO_HOTSHOT_XL_API_GRADIO") {
147
- return generateVideoWithGradioAPI({
148
  positivePrompt,
149
  negativePrompt,
150
  size,
 
5
 
6
  import { VideoOptions } from "@/types"
7
 
8
+ import { generateGradio } from "./generateGradio"
9
+ import { generateReplicate } from "./generateReplicate"
10
  import { filterOutBadWords } from "./censorship"
11
 
12
  const videoEngine = `${process.env.VIDEO_ENGINE || ""}`
 
102
  try {
103
 
104
  if (videoEngine === "VIDEO_HOTSHOT_XL_API_REPLICATE") {
105
+ return generateReplicate({
106
  positivePrompt,
107
  negativePrompt,
108
  size,
 
144
 
145
  return content
146
  } else if (videoEngine === "VIDEO_HOTSHOT_XL_API_GRADIO") {
147
+ return generateGradio({
148
  positivePrompt,
149
  negativePrompt,
150
  size,
src/app/server/actions/{generateWithGradioApi.ts → generateGradio.ts} RENAMED
@@ -4,7 +4,7 @@ import { VideoOptions } from "@/types"
4
  const gradioApi = `${process.env.VIDEO_HOTSHOT_XL_API_GRADIO || ""}`
5
  const accessToken = `${process.env.AUTH_HOTSHOT_XL_API_GRADIO_ACCESS_TOKEN || ""}`
6
 
7
- export async function generateVideoWithGradioAPI({
8
  positivePrompt = "",
9
  negativePrompt = "",
10
  size = "512x512",
 
4
  const gradioApi = `${process.env.VIDEO_HOTSHOT_XL_API_GRADIO || ""}`
5
  const accessToken = `${process.env.AUTH_HOTSHOT_XL_API_GRADIO_ACCESS_TOKEN || ""}`
6
 
7
+ export async function generateGradio({
8
  positivePrompt = "",
9
  negativePrompt = "",
10
  size = "512x512",
src/app/server/actions/{generateWithReplicateAPI.ts → generateReplicate.ts} RENAMED
@@ -8,7 +8,7 @@ const replicateToken = `${process.env.AUTH_REPLICATE_API_TOKEN || ""}`
8
  const replicateModel = `${process.env.VIDEO_HOTSHOT_XL_API_REPLICATE_MODEL || ""}`
9
  const replicateModelVersion = `${process.env.VIDEO_HOTSHOT_XL_API_REPLICATE_MODEL_VERSION || ""}`
10
 
11
- export async function generateVideoWithReplicateAPI({
12
  positivePrompt = "",
13
  negativePrompt = "",
14
  size = "512x512",
 
8
  const replicateModel = `${process.env.VIDEO_HOTSHOT_XL_API_REPLICATE_MODEL || ""}`
9
  const replicateModelVersion = `${process.env.VIDEO_HOTSHOT_XL_API_REPLICATE_MODEL_VERSION || ""}`
10
 
11
+ export async function generateReplicate({
12
  positivePrompt = "",
13
  negativePrompt = "",
14
  size = "512x512",
src/app/server/actions/interpolateGradio.ts ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ const gradioApi = `${process.env.INTERPOLATION_API_GRADIO_URL || ""}`
3
+ const accessToken = `${process.env.AUTH_INTERPOLATION_API_GRADIO_TOKEN || ""}`
4
+
5
+ export async function interpolateGradio(assetUrl: string): Promise<string> {
6
+ // we need to remove this header perhaps
7
+ const videoInBase64 = assetUrl.split("data:video/mp4;base64,").pop()
8
+
9
+ const interpolationSteps = 2
10
+ const nbFramesPerSecond = 32
11
+
12
+ const res = await fetch(gradioApi + (gradioApi.endsWith("/") ? "" : "/") + "api/predict", {
13
+ method: "POST",
14
+ headers: {
15
+ "Content-Type": "application/json",
16
+ // Authorization: `Bearer ${token}`,
17
+ },
18
+ body: JSON.stringify({
19
+ fn_index: 0, // <- important!
20
+ data: [
21
+ accessToken,
22
+ videoInBase64,
23
+ interpolationSteps,
24
+ nbFramesPerSecond
25
+ ],
26
+ }),
27
+ cache: "no-store",
28
+ // we can also use this (see https://vercel.com/blog/vercel-cache-api-nextjs-cache)
29
+ // next: { revalidate: 1 }
30
+ })
31
+
32
+ const { data } = await res.json()
33
+
34
+ if (res.status !== 200 || !data[0]?.length) {
35
+ // This will activate the closest `error.js` Error Boundary
36
+ throw new Error(`Failed to fetch data (status: ${res.status})`)
37
+ }
38
+
39
+ return data[0]
40
+ }
41
+
src/app/server/actions/{interpolate.ts → interpolateReplicate.ts} RENAMED
@@ -1,5 +1,3 @@
1
- "use server"
2
-
3
  import Replicate from "replicate"
4
 
5
  import { sleep } from "@/lib/sleep"
@@ -8,7 +6,7 @@ const replicateToken = `${process.env.AUTH_REPLICATE_API_TOKEN || ""}`
8
  const replicateModel = `${process.env.INTERPOLATION_API_REPLICATE_MODEL || ""}`
9
  const replicateModelVersion = `${process.env.INTERPOLATION_API_REPLICATE_MODEL_VERSION || ""}`
10
 
11
- export async function interpolate(input: string): Promise<string> {
12
  if (!replicateToken) {
13
  throw new Error(`you need to configure your AUTH_REPLICATE_API_TOKEN in order to use interpolation`)
14
  }
 
 
 
1
  import Replicate from "replicate"
2
 
3
  import { sleep } from "@/lib/sleep"
 
6
  const replicateModel = `${process.env.INTERPOLATION_API_REPLICATE_MODEL || ""}`
7
  const replicateModelVersion = `${process.env.INTERPOLATION_API_REPLICATE_MODEL_VERSION || ""}`
8
 
9
+ export async function interpolateReplicate(input: string): Promise<string> {
10
  if (!replicateToken) {
11
  throw new Error(`you need to configure your AUTH_REPLICATE_API_TOKEN in order to use interpolation`)
12
  }
src/app/server/actions/interpolation.ts ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ "use server"
2
+
3
+ import { interpolateGradio } from "./interpolateGradio"
4
+ import { interpolateReplicate } from "./interpolateReplicate"
5
+
6
+ const interpolationEngine = `${process.env.INTERPOLATION_ENGINE || ""}`
7
+
8
+ export async function interpolateVideo(inputVideo: string): Promise<string> {
9
+ if (!inputVideo?.length) {
10
+ throw new Error(`missing input video`)
11
+ }
12
+
13
+ try {
14
+
15
+ if (interpolationEngine === "STMFNET_REPLICATE") {
16
+ return interpolateReplicate(inputVideo)
17
+ } else if (interpolationEngine === "FILM_GRADIO") {
18
+ return interpolateGradio(inputVideo)
19
+ } else {
20
+ throw new Error(`unsupported interpolation engine "${interpolationEngine}"`)
21
+ }
22
+ } catch (err) {
23
+ throw new Error(`failed to interpolate the video ${err}`)
24
+ }
25
+ }
src/app/server/actions/models.ts CHANGED
@@ -22,6 +22,26 @@ export async function getSDXLModels(): Promise<SDXLModel[]> {
22
  const compatibleModels = content.filter(model => model.is_compatible)
23
 
24
  const hardcoded: SDXLModel[] = [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  {
26
  "image": "https://jbilcke-hf-ai-clip-factory.hf.space/images/models/sdxl-cyberpunk-2077.jpg",
27
  "title": "sdxl-cyberpunk-2077",
 
22
  const compatibleModels = content.filter(model => model.is_compatible)
23
 
24
  const hardcoded: SDXLModel[] = [
25
+ {
26
+ "image": "https://jbilcke-hf-ai-clip-factory.hf.space/images/models/sdxl-starfield.jpg",
27
+ "title": "sdxl-starfield",
28
+ "repo": "jbilcke-hf/sdxl-starfield",
29
+ "trigger_word": "starfield-style",
30
+ "weights": "pytorch_lora_weights.safetensors",
31
+ "is_compatible": true,
32
+ "likes": 0,
33
+ "downloads": 0
34
+ },
35
+ {
36
+ "image": "https://jbilcke-hf-ai-clip-factory.hf.space/images/models/sdxl-akira.jpg",
37
+ "title": "sdxl-akira",
38
+ "repo": "jbilcke-hf/sdxl-akira",
39
+ "trigger_word": "akira-style",
40
+ "weights": "pytorch_lora_weights.safetensors",
41
+ "is_compatible": true,
42
+ "likes": 0,
43
+ "downloads": 0
44
+ },
45
  {
46
  "image": "https://jbilcke-hf-ai-clip-factory.hf.space/images/models/sdxl-cyberpunk-2077.jpg",
47
  "title": "sdxl-cyberpunk-2077",
src/components/ui/dialog.tsx CHANGED
@@ -11,10 +11,9 @@ const Dialog = DialogPrimitive.Root
11
  const DialogTrigger = DialogPrimitive.Trigger
12
 
13
  const DialogPortal = ({
14
- className,
15
  ...props
16
  }: DialogPrimitive.DialogPortalProps) => (
17
- <DialogPrimitive.Portal className={cn(className)} {...props} />
18
  )
19
  DialogPortal.displayName = DialogPrimitive.Portal.displayName
20
 
 
11
  const DialogTrigger = DialogPrimitive.Trigger
12
 
13
  const DialogPortal = ({
 
14
  ...props
15
  }: DialogPrimitive.DialogPortalProps) => (
16
+ <DialogPrimitive.Portal {...props} />
17
  )
18
  DialogPortal.displayName = DialogPrimitive.Portal.displayName
19