File size: 4,246 Bytes
56458a8
cff48f9
56458a8
24c50a4
7175dd2
 
 
 
dd17729
 
afc2656
 
 
 
 
 
 
 
 
 
 
 
56458a8
dd17729
522e040
dd17729
 
 
a5c9736
dd17729
522e040
dd17729
522e040
56458a8
7175dd2
fec39f8
3b617f3
7175dd2
 
 
 
 
 
 
 
 
 
afc2656
 
cff48f9
56458a8
 
 
 
cff48f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56458a8
 
 
 
 
 
e290e46
 
56458a8
 
8f36174
 
dc14f2d
22fd767
1b02803
20c0b53
cff48f9
 
 
 
 
 
 
535e3fa
 
 
cff48f9
535e3fa
 
cff48f9
 
56458a8
535e3fa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import gradio as gr
from share_btn import community_icon_html, loading_icon_html, share_js

import os 
# Token read from the HF_TOKEN environment variable (None when it is unset);
# it is passed to the Llama API client below.
hf_token = os.environ.get('HF_TOKEN')
from gradio_client import Client
# Remote Space queried by infer() through its "/predict" endpoint to write the story.
client = Client("https://fffiloni-test-llama-api.hf.space/", hf_token=hf_token)

# Remote CLIP Interrogator Space queried by infer() through its "/clipi2"
# endpoint to caption the image; no token is passed to this client.
clipi_client = Client("https://fffiloni-clip-interrogator-2.hf.space/")

def get_text_after_colon(input_text):
    """Return what follows the first ":" in *input_text*, stripped of whitespace.

    When *input_text* contains no colon it is returned unchanged (no
    stripping is applied in that case).
    """
    _, separator, remainder = input_text.partition(":")

    # An empty separator means partition() found no colon at all.
    if not separator:
        return input_text

    return remainder.strip()

def infer(image_input):
    """Turn an image into a short story using two remote endpoints.

    The image is first sent to the CLIP Interrogator client ("/clipi2") to
    obtain a caption; the caption is then embedded in a prompt sent to the
    Llama client ("/predict"). Any text up to and including the first ":" in
    the model's answer is dropped before the story is returned.

    Args:
        image_input: Filepath or URL of the image to describe.

    Returns:
        A 2-tuple ``(story, group_update)`` where ``story`` is the generated
        text and ``group_update`` is a ``gr.Group`` update that makes the
        target group visible.

    Raises:
        gr.Error: If no image was provided.
    """
    # Fail fast with a readable message instead of forwarding an empty
    # input to the remote captioning service.
    if not image_input:
        raise gr.Error("Please upload an image first.")

    clipi_result = clipi_client.predict(
        image_input,  # str (filepath or URL to image) in 'parameter_3' Image component
        "best",  # str in 'Select mode' Radio component
        4,  # int | float (numeric value between 2 and 24) in 'best mode max flavors' Slider component
        api_name="/clipi2",
    )
    print(clipi_result)

    # Only the first element of the interrogator result is used as the caption.
    llama_q = f"""
    I'll give you a simple image caption, from i want you to provide a story that would fit well with the image:
    '{clipi_result[0]}'
    
    """

    result = client.predict(
        llama_q,  # str in 'Message' Textbox component
        api_name="/predict",
    )

    print(f"Llama2 result: {result}")

    # Drop everything up to and including the first ":" in the answer.
    result = get_text_after_colon(result)

    return result, gr.Group.update(visible=True)

# Custom stylesheet handed to gr.Blocks(css=...). It targets the element ids
# assigned in the layout: "col-container" (centred, max 910px wide),
# "share-btn-container" (black pill-shaped flex container, hidden via the
# ".hidden" class) and "share-btn" (white bold button text), plus a spin
# animation for ".animate-spin".
# NOTE(review): the `a {...}` rule appears twice in a row; the second copy is redundant.
css="""
#col-container {max-width: 910px; margin-left: auto; margin-right: auto;}
a {text-decoration-line: underline; font-weight: 600;}
a {text-decoration-line: underline; font-weight: 600;}
.animate-spin {
  animation: spin 1s linear infinite;
}
@keyframes spin {
  from {
      transform: rotate(0deg);
  }
  to {
      transform: rotate(360deg);
  }
}
#share-btn-container {
  display: flex; 
  padding-left: 0.5rem !important; 
  padding-right: 0.5rem !important; 
  background-color: #000000; 
  justify-content: center; 
  align-items: center; 
  border-radius: 9999px !important; 
  max-width: 13rem;
}
div#share-btn-container > div {
    flex-direction: row;
    background: black;
    align-items: center;
}
#share-btn-container:hover {
  background-color: #060606;
}
#share-btn {
  all: initial; 
  color: #ffffff;
  font-weight: 600; 
  cursor:pointer; 
  font-family: 'IBM Plex Sans', sans-serif; 
  margin-left: 0.5rem !important; 
  padding-top: 0.5rem !important; 
  padding-bottom: 0.5rem !important;
  right:0;
}
#share-btn * {
  all: unset;
}
#share-btn-container div:nth-child(-n+2){
  width: auto !important;
  min-height: 0px !important;
}
#share-btn-container .wrap {
  display: none !important;
}
#share-btn-container.hidden {
  display: none!important;
}
"""

# Build the UI: a title, then a row with two columns (image input + submit
# button in the first, generated story + share group in the second), followed
# by example images.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            """
            <h1 style="text-align: center">Image to Story</h1>
            <p style="text-align: center">Upload an image, get a story made by Llama2 !</p>
            """
        )
        with gr.Row():
            with gr.Column():
                # type="filepath" so infer() receives a path string rather than image data.
                image_in = gr.Image(label="Image input", type="filepath", elem_id="image-in")
                submit_btn = gr.Button('Tell me a story')
            with gr.Column():
                #caption = gr.Textbox(label="Generated Caption")
                story = gr.Textbox(label="generated Story", elem_id="story")

                # Created hidden; infer() returns an update with visible=True
                # as its second output, which is routed to this group.
                with gr.Group(elem_id="share-btn-container", visible=False) as share_group:
                    community_icon = gr.HTML(community_icon_html)
                    loading_icon = gr.HTML(loading_icon_html)
                    share_button = gr.Button("Share to community", elem_id="share-btn")
        
        # Example images wired to the same function and outputs as the submit button.
        # NOTE(review): cache_examples=True presumably runs infer() on the examples
        # ahead of time (hitting both remote endpoints) — confirm for the Gradio version in use.
        gr.Examples(examples=[["./examples/crabby.png"],["./examples/hopper.jpeg"]],
                    fn=infer,
                    inputs=[image_in],
                    outputs=[story, share_group],
                    cache_examples=True
                   )
    # Submit: run infer() on the uploaded image and fill the story box / update the share group.
    submit_btn.click(fn=infer, inputs=[image_in], outputs=[story, share_group])
    # Share: no Python callback (fn is None); only the share_js script imported from share_btn is attached.
    share_button.click(None, [], [], _js=share_js)

# Enable the request queue with max_size=12, then start the app.
demo.queue(max_size=12).launch()