
Streaming Calls with FastAPI

2024-08-20
python
import os
import contextlib
from typing import TypedDict
from collections.abc import AsyncGenerator

import uvicorn
from fastapi import Body
from fastapi import FastAPI
from fastapi import Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from fastapi.responses import StreamingResponse
from openai import AsyncOpenAI
from openai import AsyncStream
from openai.types.chat import ChatCompletionChunk


class LifespanState(TypedDict):
    openai_client: AsyncOpenAI


@contextlib.asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncGenerator[LifespanState, None]:
    # Create one shared client at startup; the yielded dict becomes the
    # lifespan state, which handlers can read via request.state.
    openai_client = AsyncOpenAI(
        api_key=os.getenv("OPENAI_API_KEY"),
        base_url=os.getenv("OPENAI_BASE_URL"),
    )
    try:
        yield {
            "openai_client": openai_client,
        }
    finally:
        # Close the client's connection pool on shutdown.
        await openai_client.close()


app = FastAPI(lifespan=lifespan)

# Allow requests from any origin, any method (GET, POST, PUT, DELETE, etc.),
# and any header. Note: restrict allow_origins in production; the CORS spec
# disallows a wildcard origin combined with credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


async def generate(client: AsyncOpenAI, text: str) -> AsyncGenerator[str, None]:
    # With stream=True the client returns an AsyncStream that yields
    # ChatCompletionChunk objects as the model produces tokens.
    stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
        model="yi-medium",
        messages=[{"role": "user", "content": text}],
        stream=True,
    )
    async for event in stream:
        # Each chunk carries an incremental piece of the reply; the final
        # chunk's delta.content may be None, so skip empty deltas.
        delta = event.choices[0].delta.content
        if delta:
            yield delta
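

# A minimal SSE-framed variant of generate(); this helper is an assumption,
# not part of the original. Browser EventSource clients require the SSE wire
# format: one or more "data: <line>" fields terminated by a blank line.
async def generate_sse(client: AsyncOpenAI, text: str) -> AsyncGenerator[str, None]:
    async for chunk in generate(client, text):
        # Split on newlines so multi-line chunks stay spec-compliant.
        yield "".join(f"data: {line}\n" for line in chunk.split("\n")) + "\n"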


@app.post("/chat")
async def chat(
    request: Request,
    text: str = Body(..., embed=True),
) -> Response:
    stream = generate(request.state.openai_client, text)
    return StreamingResponse(
        stream,
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )


if __name__ == "__main__":
    uvicorn.run(app=app, host="0.0.0.0", port=8000)
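
To consume the stream from Python, here is a minimal client sketch, assuming the server is running locally on port 8000 (the prompt "Hello!" is just a placeholder). It uses httpx, whose streaming API iterates over the response body as it arrives. Note that Body(..., embed=True) on the server side means the request body must be a JSON object like {"text": "..."}.

python
import asyncio

import httpx


async def main() -> None:
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST",
            "http://localhost:8000/chat",
            json={"text": "Hello!"},
        ) as response:
            response.raise_for_status()
            # Print each chunk as soon as it arrives.
            async for chunk in response.aiter_text():
                print(chunk, end="", flush=True)


if __name__ == "__main__":
    asyncio.run(main())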