# FastAPI调用 (calling an OpenAI-compatible chat model through a FastAPI service)
# 2024-08-20

import os
import contextlib
from typing import TypedDict
from collections.abc import AsyncGenerator

import uvicorn
from openai import AsyncOpenAI
from fastapi import Body
from fastapi import FastAPI
from fastapi import Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware


# Shape of the state mapping yielded by the app's lifespan handler; each key is
# later read back by request handlers through ``request.state``.
LifespanState = TypedDict("LifespanState", {"openai_client": AsyncOpenAI})


@contextlib.asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncGenerator[LifespanState, None]:
    """Own a shared OpenAI client for the whole lifetime of the application.

    The client is configured from the ``OPENAI_API_KEY`` and ``OPENAI_BASE_URL``
    environment variables. The yielded mapping becomes the lifespan state that
    request handlers read via ``request.state.openai_client``.

    The client is used as an async context manager so that its underlying HTTP
    connection pool is closed when the application shuts down (previously it
    was never closed).
    """
    async with AsyncOpenAI(
        api_key=os.getenv("OPENAI_API_KEY"),
        base_url=os.getenv("OPENAI_BASE_URL"),
    ) as openai_client:
        yield {"openai_client": openai_client}


app = FastAPI(lifespan=lifespan)

# Allow cross-origin requests from any origin, with any method
# (GET, POST, PUT, DELETE, ...) and any request header.
# Credentials stay disabled: a wildcard origin must not be combined with
# credentialed CORS requests (with allow_credentials=True, Starlette echoes back
# whatever Origin the caller sends), and the endpoints in this module do not
# read cookies or other browser credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)


async def generate(client: AsyncOpenAI, text: str) -> str | None:
    response = await client.chat.completions.create(
        model="yi-medium", messages=[{"role": "user", "content": text}], stream=False
    )
    return response.choices[0].message.content


@app.post("/chat")
async def chat(
    request: Request,
    text: str = Body(..., embed=True),
) -> JSONResponse:
    """Forward one user message to the chat model and return its reply.

    The request body carries the message under the ``text`` key (embedded
    body parameter); the reply is returned as ``{"response": ...}``, where the
    value is ``None``/``null`` if ``generate`` produced no content. The shared
    client comes from the lifespan state exposed on ``request.state``.
    """
    shared_client = request.state.openai_client
    reply = await generate(shared_client, text)
    payload = {"response": reply}
    return JSONResponse(payload)


if __name__ == "__main__":
    # Script entry point: serve the app on every network interface, port 8000.
    bind_host, bind_port = "0.0.0.0", 8000
    uvicorn.run(app, host=bind_host, port=bind_port)