Gradio Server
2024-09-02
```python
import argparse
import os
from typing import Generator

import gradio as gr
from openai import OpenAI

# Argument parser setup
parser = argparse.ArgumentParser(
    description="Chatbot Interface with Customizable Parameters"
)
parser.add_argument(
    # For a local vLLM server, use e.g. "http://localhost:8000/v1"
    "--model-url",
    type=str,
    default="https://api.deepseek.com/v1",
    help="Model URL",
)
parser.add_argument(
    "-m",
    "--model",
    type=str,
    default="deepseek-chat",
    help="Model name for the chatbot",
)
parser.add_argument(
    "--temp", type=float, default=0.8, help="Temperature for text generation"
)
parser.add_argument(
    "--stop-token-ids", type=str, default="", help="Comma-separated stop token IDs"
)
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8001)

# Parse the arguments
args = parser.parse_args()

# Read the API key from the environment rather than hardcoding the secret.
# For a local vLLM server, any placeholder value (e.g. "EMPTY") works.
openai_api_key = os.environ.get("OPENAI_API_KEY", "EMPTY")
openai_api_base = args.model_url

# Create an OpenAI client to interact with the API server
client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)
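# Because the OpenAI client only needs an OpenAI-compatible base URL, the
# same code can talk to DeepSeek or a local vLLM server without changes.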


def predict(
    message: str, history: list[tuple[str, str]], system_message: str
) -> Generator[str, None, None]:
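    # Gradio treats a generator callback as a streaming handler: each yielded
    # string replaces the message currently shown in the chat window, so we
    # yield the accumulated text rather than per-chunk deltas.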
    # Convert chat history to OpenAI format
    history_openai_format = [{"role": "system", "content": system_message}]
    for human, assistant in history:
        history_openai_format.append({"role": "user", "content": human})
        history_openai_format.append({"role": "assistant", "content": assistant})
    history_openai_format.append({"role": "user", "content": message})

    # Create a chat completion request and send it to the API server
    stream = client.chat.completions.create(
        model=args.model,  # Model name to use
        messages=history_openai_format,  # type: ignore # Chat history
        temperature=args.temp,  # Temperature for text generation
        stream=True,  # Stream response
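        # The extra_body fields below are vLLM server extensions, not part of
        # the official OpenAI API; a hosted endpoint such as DeepSeek may
        # ignore or reject them.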
        extra_body={
            "repetition_penalty": 1,
            "stop_token_ids": (
                [
                    int(token_id.strip())
                    for token_id in args.stop_token_ids.split(",")
                    if token_id.strip()
                ]
                if args.stop_token_ids
                else []
            ),
        },
        max_tokens=2048,
    )

    # Read and return generated text from response stream
    partial_message = ""
    for chunk in stream:
        partial_message += chunk.choices[0].delta.content or ""  # type: ignore
        yield partial_message


# Create and launch a chat interface with Gradio
gr.ChatInterface(
    predict,
    additional_inputs=[
        gr.Textbox("You are a helpful assistant.", label="System Prompt"),
    ],
    additional_inputs_accordion=gr.Accordion(open=True),
).queue().launch(server_name=args.host, server_port=args.port, share=True)
```
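
To target a local vLLM server instead of the DeepSeek API, override the URL and model name at launch, e.g. `python gradio_server.py --model-url http://localhost:8000/v1 -m <served-model-name>` (the filename `gradio_server.py` is just an assumed save name). With `share=True`, Gradio also creates a temporary public link in addition to the local server on `--host`/`--port`.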