FastAPI Streaming Async Calls

2024-08-20

The complete example below streams a chat completion from the OpenAI SDK through a FastAPI endpoint, forwarding the model's output to the client as it is generated.

```python
import os
import contextlib
from typing import TypedDict
from collections.abc import Generator
from collections.abc import AsyncGenerator

import uvicorn
from openai import OpenAI
from openai import Stream
from fastapi import Body
from fastapi import FastAPI
from fastapi import Request
from fastapi.responses import Response
from fastapi.responses import StreamingResponse
from openai.types.chat import ChatCompletionChunk
from fastapi.middleware.cors import CORSMiddleware


class LifespanState(TypedDict):
    openai_client: OpenAI


@contextlib.asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncGenerator[LifespanState, None]:
    # Create the OpenAI client once at startup. FastAPI exposes the yielded
    # dict as lifespan state, reachable per request via `request.state`.
    openai_client = OpenAI(
        api_key=os.getenv("OPENAI_API_KEY"),
        base_url=os.getenv("OPENAI_BASE_URL"),
    )
    yield {
        "openai_client": openai_client,
    }
    # Release the client's HTTP connection pool at shutdown.
    openai_client.close()


app = FastAPI(lifespan=lifespan)

# Allow all origins, all HTTP methods (GET, POST, PUT, DELETE, etc.), and all
# headers. Note that browsers reject a wildcard origin on credentialed
# requests, so pin allow_origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def generate(client: OpenAI, text: str) -> Generator[str, None, None]:
    """Yield the model's reply incrementally, one content delta at a time."""
    stream: Stream[ChatCompletionChunk] = client.chat.completions.create(
        model="yi-medium",
        messages=[{"role": "user", "content": text}],
        stream=True,
    )
    for event in stream:
        # Each chunk carries an optional content delta; skip empty chunks
        # (e.g. the role-only first chunk and the final finish chunk).
        current_response = event.choices[0].delta.content
        if current_response:
            yield current_response


@app.post("/chat")
async def chat(
    request: Request,
    text: str = Body(..., embed=True),
) -> Response:
    stream = generate(request.state.openai_client, text)
    return StreamingResponse(
        stream,
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )


if __name__ == "__main__":
    uvicorn.run(app=app, host="0.0.0.0", port=8000)
```
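
The app above streams through the synchronous client, which FastAPI offloads to a threadpool. For a fully asynchronous pipeline, the same SDK ships an `AsyncOpenAI` client whose stream can be consumed with `async for`; a minimal sketch, assuming the same environment variables and model:

```python
from collections.abc import AsyncGenerator

from openai import AsyncOpenAI


async def agenerate(client: AsyncOpenAI, text: str) -> AsyncGenerator[str, None]:
    # With stream=True the async client returns an AsyncStream of chunks.
    stream = await client.chat.completions.create(
        model="yi-medium",
        messages=[{"role": "user", "content": text}],
        stream=True,
    )
    async for event in stream:
        current_response = event.choices[0].delta.content
        if current_response:
            yield current_response
```

`StreamingResponse` accepts async generators directly, so the endpoint only needs to swap `generate` for `agenerate` (and the lifespan would construct `AsyncOpenAI` instead of `OpenAI`).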
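
On the client side, any HTTP library that exposes the raw response body can consume the stream; here is a sketch using `httpx` (an assumed extra dependency), with the URL and payload matching the endpoint above:

```python
import httpx

with httpx.stream(
    "POST",
    "http://localhost:8000/chat",
    json={"text": "Hello!"},  # Body(..., embed=True) expects {"text": ...}
    timeout=None,  # the stream may stay open longer than the default timeout
) as response:
    for chunk in response.iter_text():
        print(chunk, end="", flush=True)
```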