import os
import signal
import threading
import time
import subprocess

# Path to the Ollama binary plus handles to the background server process and thread.
OLLAMA = os.path.expanduser("~/ollama")
process = None
OLLAMA_SERVICE_THREAD = None

# Download the standalone Linux binary on first run and make it executable.
if not os.path.exists(OLLAMA):
    subprocess.run(
        "curl -L https://ollama.com/download/ollama-linux-amd64 -o ~/ollama",
        shell=True,
        check=True,
    )
    os.chmod(OLLAMA, 0o755)
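# Optional sanity check (not part of the original flow): confirm the downloaded
# binary actually runs before trying to serve with it. Assumes the ollama CLI's
# usual --version flag.
#
#     subprocess.run(f"{OLLAMA} --version", shell=True, check=True)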
def ollama_service_thread():
    """Run `ollama serve` in its own process group so it can be killed cleanly."""
    global process
    process = subprocess.Popen("~/ollama serve", shell=True, preexec_fn=os.setsid)
    # Block this thread until the server exits, so the thread's lifetime tracks the server's.
    process.wait()


def terminate():
    """Stop the Ollama server and reset the global handles."""
    global process, OLLAMA_SERVICE_THREAD
    if process:
        # Kill the whole process group created via preexec_fn=os.setsid above,
        # so both the shell wrapper and the server itself go away.
        os.killpg(os.getpgid(process.pid), signal.SIGTERM)
    if OLLAMA_SERVICE_THREAD:
        OLLAMA_SERVICE_THREAD.join()
    process = None
    OLLAMA_SERVICE_THREAD = None
    print("Ollama service stopped.")
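# Intended lifecycle, for reference: type "/init" once in the chat box to start
# the server (init() below), chat normally, and type "/bye" to call terminate().
# A rough manual equivalent from a Python shell, using the names defined in this
# file, would be:
#
#     init()        # start `~/ollama serve` in a background thread
#     ...           # chat through the UI or the AsyncClient
#     terminate()   # SIGTERM the server's process group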
import ollama
import gradio as gr
from ollama import AsyncClient

# Async client pointed at the local server started by `ollama serve`.
client = AsyncClient(host="http://localhost:11434", timeout=120)
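# For orientation, a minimal non-streaming request through this client would
# look roughly like the sketch below (the model name is only an example and
# must have been pulled first):
#
#     resp = await client.chat(
#         model="qwen2:0.5b",
#         messages=[{"role": "user", "content": "Hello"}],
#     )
#     print(resp["message"]["content"])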
HF_TOKEN = os.environ.get("HF_TOKEN", None)  # read from the environment; not used below

TITLE = "<h1><center>Ollama Chat</center></h1>"

DESCRIPTION = """
<center>
<p>Feel free to test models with Ollama.
<br>
On the first run, type <em>/init</em> to launch the service.
<br>
Type <em>/pull model_name</em> to pull a model.
</p>
</center>
"""

# Set to a non-empty string once /init has run, gating normal chat requests.
INIT_SIGN = ""
def init():
    """Start the Ollama server in a background thread and mark it as initialized."""
    global OLLAMA_SERVICE_THREAD, INIT_SIGN
    OLLAMA_SERVICE_THREAD = threading.Thread(target=ollama_service_thread)
    OLLAMA_SERVICE_THREAD.start()
    print("Giving ollama serve a moment")
    time.sleep(10)
    INIT_SIGN = "FINISHED"
def ollama_func(command):
    """Dispatch slash commands typed into the chat box."""
    # Split into the command itself and an optional single argument,
    # e.g. "/pull qwen2:0.5b" -> ("/pull", "qwen2:0.5b").
    if " " in command:
        c1, c2 = command.split(" ", 1)
    else:
        c1, c2 = command, ""
    function_map = {
        "/init": init,
        "/pull": lambda: ollama.pull(c2),
        "/list": ollama.list,  # return value is not surfaced in the UI
        "/bye": terminate,
    }
    if c1 in function_map:
        function_map[c1]()
        return "Running..."
    else:
        return "No supported command."
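# Example commands as typed into the chat box (the model name is illustrative):
#
#     /init                # start `ollama serve` after the binary download
#     /pull qwen2:0.5b     # pull a model so it can be used in the "Model" box
#     /list                # calls ollama.list(); its result is currently discarded
#     /bye                 # terminate the server process group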
def launch():
    """Restart the server thread when the chat handler finds no running process."""
    global OLLAMA_SERVICE_THREAD
    OLLAMA_SERVICE_THREAD = threading.Thread(target=ollama_service_thread)
    OLLAMA_SERVICE_THREAD.start()
async def stream_chat(
    message: str,
    history: list,
    model: str,
    temperature: float,
    max_new_tokens: int,
    top_p: float,
    top_k: int,
    penalty: float,
):
    print(f"message: {message}")
    # Rebuild the full conversation from Gradio's (user, assistant) history pairs.
    conversation = []
    for prompt, answer in history:
        conversation.extend([
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer},
        ])
    conversation.append({"role": "user", "content": message})

    print(f"Conversation is -\n{conversation}")

    if message.startswith("/"):
        # Slash commands are handled locally instead of being sent to the model.
        resp = ollama_func(message)
        yield resp
    elif not INIT_SIGN:
        yield "Please initialize Ollama"
    else:
        if not process:
            # The server is not running (never started here, or stopped via /bye); restart it.
            launch()
            print("Giving ollama serve a moment")
            time.sleep(10)

        # Stream tokens from the model and yield the growing answer so the
        # Gradio chatbot updates incrementally.
        buffer = ""
        async for part in await client.chat(
            model=model,
            stream=True,
            messages=conversation,
            keep_alive="60s",
            options={
                'num_predict': max_new_tokens,
                'temperature': temperature,
                'top_p': top_p,
                'top_k': top_k,
                'repeat_penalty': penalty,
                'low_vram': True,
            },
        ):
            buffer += part['message']['content']
            yield buffer
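# For a two-turn history, the `conversation` list built above would look like
# this (contents are illustrative):
#
#     [
#         {"role": "user", "content": "Hi"},
#         {"role": "assistant", "content": "Hello! How can I help?"},
#         {"role": "user", "content": "Tell me a joke"},
#     ]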
chatbot = gr.Chatbot(height=600, placeholder=DESCRIPTION)

with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    gr.HTML(TITLE)
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Textbox(
                value="qwen2:0.5b",
                label="Model",
                render=False,
            ),
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.8,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=128,
                maximum=2048,
                step=1,
                value=1024,
                label="Max New Tokens",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=0.8,
                label="top_p",
                render=False,
            ),
            gr.Slider(
                minimum=1,
                maximum=20,
                step=1,
                value=20,
                label="top_k",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=2.0,
                step=0.1,
                value=1.0,
                label="Repetition penalty",
                render=False,
            ),
        ],
        examples=[
            ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
            ["What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."],
            ["Tell me a random fun fact about the Roman Empire."],
            ["Show me a code snippet of a website's sticky header in CSS and JavaScript."],
        ],
        cache_examples=False,
    )
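# On a hosted Space the plain demo.launch() below is enough; for a local run
# you may want explicit networking, e.g. (standard Gradio parameters, not part
# of the original script):
#
#     demo.launch(server_name="0.0.0.0", server_port=7860)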
if __name__ == "__main__":
    demo.launch()