掘金 后端 ( ) • 2024-06-26 17:58

FastAPI中自定义白名单黑名单中间件以及底层源码分析

定义白名单Whitelist中间件

from fastapi import FastAPI, Request
from fastapi.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp

app = FastAPI()

WHITELIST = ["127.0.0.1", "172.0.0.1"]

class WhiteIPAddrMiddleware(BaseHTTPMiddleware):
  def __init__(self, app: ASGIApp, whitelist: list[str]):
    super().__init__(app)
    self.whitelist = whitelist
  
  await def dispatch(self, request: Request, call_next):
    client_ip = request.client.host
    if self.whitelist and client_ip not in self.whitelist:
      return Response(status_code=403, content="Forbidden")
    response = await call_next(request)
    return response
   
   
app.add_middleware(WhiteIPAddrMiddleware, whitelist=WHITELIST)


@app.get("/")
async def root():
  return {"msg": "Hello World!"}

定义黑名单Blacklist中间件

from fastapi import FastAPI, Request
from fastapi.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp

app = FastAPI()

BLACKLIST = ["127.0.0.1", "172.0.0.1"]

class BLACKIPAddrMiddleware(BaseHTTPMiddleware):
  def __init__(self, app: ASGIApp, blacklist: list[str]):
    super().__init__(app)
    self.blacklist= blacklist
  
  await def dispatch(self, request: Request, call_next):
    client_ip = request.client.host
    if self.blacklist and client_ip in self.blacklist:
      return Response(status_code=403, content="Forbidden")
    response = await call_next(request)
    return response
   
   
app.add_middleware(BlackIPAddrMiddleware, blacklist=BLACKLIST)


@app.get("/")
async def root():
  return {"msg": "Hello World!"}

代码解析

为什么继承 BaseHTTPMiddleware​是必须的?

BaseHTTPMiddleware​ 是 Starlette 提供的一个类,用于简化中间件的创建。它封装了一些基础的方法和属性,比如 __call__​ 方法来处理请求。

并且对于每一个自定义的中间件,需要创建一个继承BaseHTTPMiddleware类,并覆盖 dispatch​ 方法。这种面向对象的实现方法,使得代码便于扩展和维护。

为什么super().init(app)是必须的?

super().__init__(app)​ 在这里是必须的,因为它确保了父类 BaseHTTPMiddleware​ 的初始化方法被正确调用,这样才能正确设置中间件的基础属性和方法。让我们逐步详细解析为什么以及是如何工作的。

1. 类继承和初始化

在 Python 中,类继承允许一个类(子类)继承另一个类(父类)的属性和方法。子类可以重载父类的方法以实现特定功能,但仍然需要调用父类的方法来保留和初始化父类的属性和功能。

2. BaseHTTPMiddleware​ 的作用

BaseHTTPMiddleware​ 是一个抽象基类,它封装了一些基础的中间件功能。其主要任务是接收请求、处理中间件逻辑,然后将请求传递给下一个中间件或实际的请求处理代码。如果你的中间件类有多层继承,super().__init__​ 的调用可以确保继承链中的每一个类的构造函数都能得到执行。

3. 参数app的传递

BaseHTTPMiddleware​ 的初始化主要包括设置一些基本的属性,比如应用实例 app​,传递app​为参数,使得中间件知道它是属于哪个应用的。

为什么dispatch方法是必须的?

  • 抽象基类模式:

    • dispatch​ 方法是一个抽象方法,意义在于强制要求所有子类必须实现这个方法。这确保了所有继承 BaseHTTPMiddleware​ 的类都有一个用于处理请求的机制。
    • raise NotImplementedError("dispatch must be implemented by child class")​ 表明每个具体的中间件类必须自行实现 dispatch​,否则会报错。
  • 请求处理的核心:

    • dispatch​ 方法是指定在请求进入你的应用前以及处理后被调用的逻辑。在其中,你可以对请求进行检查、修改、记录日志等操作。
    • 子类通过实现 dispatch​ 方法来定义中间件的具体行为。
  • 实现自定义逻辑:

    • 通过编写 dispatch​ 方法,开发者可以插入自定义逻辑来处理请求。比如IP白名单过滤、黑名单过滤、权限检查等。
    • 这是实现你自定义中间件功能的主要地方。

简化版本的 BaseHTTPMiddleware​ 可能看起来像这样:

class BaseHTTPMiddleware:
    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        assert scope['type'] == 'http'
        response = await self.dispatch(scope, receive, send)
        if response is None:
            await self.app(scope, receive, send)
        else:
            await response(scope, receive, send)

    async def dispatch(self, scope, receive, send):
        raise NotImplementedError("dispatch must be implemented by child class")
  • __init__方法: 初始化中间件,设置中间件所属的应用实例 self.app​。
  • __call__方法: 这是 ASGI 应用的入口点,用于处理传入的请求。这部分代码会先调用 dispatch​ 方法进行请求处理。
  • dispatch方法: 必须由子类实现的抽象方法,用于定义中间件具体的请求处理逻辑。

BaseHTTPMiddleware源码分析

# starlette/middleware/base.py

import typing

import anyio

from starlette.background import BackgroundTask
from starlette.requests import Request
from starlette.responses import ContentStream, Response, StreamingResponse
from starlette.types import ASGIApp, Message, Receive, Scope, Send

RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
DispatchFunction = typing.Callable[
    [Request, RequestResponseEndpoint], typing.Awaitable[Response]
]
T = typing.TypeVar("T")


class BaseHTTPMiddleware:
    def __init__(
        self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None
    ) -> None:
        self.app = app
        self.dispatch_func = self.dispatch if dispatch is None else dispatch

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        response_sent = anyio.Event()

        async def call_next(request: Request) -> Response:
            app_exc: typing.Optional[Exception] = None
            send_stream, recv_stream = anyio.create_memory_object_stream()

            async def receive_or_disconnect() -> Message:
                if response_sent.is_set():
                    return {"type": "http.disconnect"}

                async with anyio.create_task_group() as task_group:

                    async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
                        result = await func()
                        task_group.cancel_scope.cancel()
                        return result

                    task_group.start_soon(wrap, response_sent.wait)
                    message = await wrap(request.receive)

                if response_sent.is_set():
                    return {"type": "http.disconnect"}

                return message

            async def close_recv_stream_on_response_sent() -> None:
                await response_sent.wait()
                recv_stream.close()

            async def send_no_error(message: Message) -> None:
                try:
                    await send_stream.send(message)
                except anyio.BrokenResourceError:
                    # recv_stream has been closed, i.e. response_sent has been set.
                    return

            async def coro() -> None:
                nonlocal app_exc

                async with send_stream:
                    try:
                        await self.app(scope, receive_or_disconnect, send_no_error)
                    except Exception as exc:
                        app_exc = exc

            task_group.start_soon(close_recv_stream_on_response_sent)
            task_group.start_soon(coro)

            try:
                message = await recv_stream.receive()
                info = message.get("info", None)
                if message["type"] == "http.response.debug" and info is not None:
                    message = await recv_stream.receive()
            except anyio.EndOfStream:
                if app_exc is not None:
                    raise app_exc
                raise RuntimeError("No response returned.")

            assert message["type"] == "http.response.start"

            async def body_stream() -> typing.AsyncGenerator[bytes, None]:
                async with recv_stream:
                    async for message in recv_stream:
                        assert message["type"] == "http.response.body"
                        body = message.get("body", b"")
                        if body:
                            yield body

                if app_exc is not None:
                    raise app_exc

            response = _StreamingResponse(
                status_code=message["status"], content=body_stream(), info=info
            )
            response.raw_headers = message["headers"]
            return response

        async with anyio.create_task_group() as task_group:
            request = Request(scope, receive=receive)
            response = await self.dispatch_func(request, call_next)
            await response(scope, receive, send)
            response_sent.set()

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        raise NotImplementedError()  # pragma: no cover

这段代码实现了一个基于Starlette框架的HTTP中间件类BaseHTTPMiddleware​。它的主要功能是拦截HTTP请求并在请求处理前后执行一些自定义逻辑。以下是对代码的详细分析:

导入模块

import typing
import anyio

from starlette.background import BackgroundTask
from starlette.requests import Request
from starlette.responses import ContentStream, Response, StreamingResponse
from starlette.types import ASGIApp, Message, Receive, Scope, Send
  • typing​:用于类型提示。
  • anyio​:一个异步I/O库,提供了对异步任务和流的支持。anyio是什么
  • starlette​相关模块:用于处理HTTP请求和响应。

类型定义

RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
DispatchFunction = typing.Callable[
    [Request, RequestResponseEndpoint], typing.Awaitable[Response]
]
T = typing.TypeVar("T")
  • RequestResponseEndpoint​:一个函数类型,接收一个Request​对象并返回一个Response​对象。
  • DispatchFunction​:一个函数类型,接收一个Request​对象和一个RequestResponseEndpoint​,返回一个Response​对象。
  • T​:一个泛型类型变量。

BaseHTTPMiddleware​类

初始化方法

class BaseHTTPMiddleware:
    def __init__(
        self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None
    ) -> None:
        self.app = app
        self.dispatch_func = self.dispatch if dispatch is None else dispatch
  • app​:ASGI应用程序实例。
  • dispatch​:可选的调度函数,如果未提供,则使用类的dispatch​方法。

__call__​方法

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
    if scope["type"] != "http":
        await self.app(scope, receive, send)
        return

    response_sent = anyio.Event()
  • 检查请求类型是否为HTTP,如果不是,则直接调用下一个中间件或应用程序。
  • 创建一个anyio.Event​对象,用于在响应发送后通知其他任务。

call_next​函数

async def call_next(request: Request) -> Response:
    app_exc: typing.Optional[Exception] = None
    send_stream, recv_stream = anyio.create_memory_object_stream()
  • call_next​函数用于调用下一个中间件或应用程序,并处理请求和响应。
  • 创建两个内存对象流send_stream​和recv_stream​,用于在任务之间传递消息。
receive_or_disconnect​函数
async def receive_or_disconnect() -> Message:
    if response_sent.is_set():
        return {"type": "http.disconnect"}

    async with anyio.create_task_group() as task_group:

        async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
            result = await func()
            task_group.cancel_scope.cancel()
            return result

        task_group.start_soon(wrap, response_sent.wait)
        message = await wrap(request.receive)

    if response_sent.is_set():
        return {"type": "http.disconnect"}

    return message
  • receive_or_disconnect​函数用于接收请求消息或断开连接。
  • 如果响应已经发送,则返回断开连接消息。
  • 使用anyio.create_task_group​创建任务组,并启动两个任务:等待响应发送和接收请求消息。
close_recv_stream_on_response_sent​函数
async def close_recv_stream_on_response_sent() -> None:
    await response_sent.wait()
    recv_stream.close()
  • 等待响应发送后关闭接收流。
send_no_error​函数
async def send_no_error(message: Message) -> None:
    try:
        await send_stream.send(message)
    except anyio.BrokenResourceError:
        return
  • 发送消息到发送流,如果发送流已经关闭,则忽略错误。
coro​函数
async def coro() -> None:
    nonlocal app_exc

    async with send_stream:
        try:
            await self.app(scope, receive_or_disconnect, send_no_error)
        except Exception as exc:
            app_exc = exc
  • coro​函数用于调用下一个中间件或应用程序,并捕获任何异常。
启动任务组
task_group.start_soon(close_recv_stream_on_response_sent)
task_group.start_soon(coro)
  • 启动两个任务:等待响应发送后关闭接收流和调用下一个中间件或应用程序。
处理响应消息
try:
    message = await recv_stream.receive()
    info = message.get("info", None)
    if message["type"] == "http.response.debug" and info is not None:
        message = await recv_stream.receive()
except anyio.EndOfStream:
    if app_exc is not None:
        raise app_exc
    raise RuntimeError("No response returned.")

assert message["type"] == "http.response.start"
  • 接收响应消息并处理,如果没有响应返回,则抛出异常。
body_stream​生成器
async def body_stream() -> typing.AsyncGenerator[bytes, None]:
    async with recv_stream:
        async for message in recv_stream:
            assert message["type"] == "http.response.body"
            body = message.get("body", b"")
            if body:
                yield body

    if app_exc is not None:
        raise app_exc
  • body_stream​生成器用于生成响应体的字节流。
创建响应对象
response = _StreamingResponse(
    status_code=message["status"], content=body_stream(), info=info
)
response.raw_headers = message["headers"]
return response
  • 创建一个_StreamingResponse​对象,并设置响应头和状态码。

调用调度函数

async with anyio.create_task_group() as task_group:
    request = Request(scope, receive=receive)
    response = await self.dispatch_func(request, call_next)
    await response(scope, receive, send)
    response_sent.set()
  • 创建任务组并调用调度函数,处理请求并发送响应。

dispatch​方法

async def dispatch(
    self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
    raise NotImplementedError()  # pragma: no cover
  • dispatch​方法是一个抽象方法,子类需要实现该方法以定义自定义的请求处理逻辑。

执行流程

这段代码的执行流程可以分为以下几个步骤:

  1. 初始化中间件

    • BaseHTTPMiddleware​类的实例化时,__init__​方法被调用,初始化app​和dispatch_func​属性。
  2. 调用中间件

    • 当一个HTTP请求到达时,__call__​方法被调用。
    • 如果请求类型不是HTTP(例如WebSocket),则直接调用下一个中间件或应用程序。
    • 创建一个anyio.Event​对象response_sent​,用于在响应发送后通知其他任务。
  3. 定义内部函数call_next​:

    • call_next​函数用于调用下一个中间件或应用程序,并处理请求和响应。
    • 创建两个内存对象流send_stream​和recv_stream​,用于在任务之间传递消息。
  4. 定义内部函数receive_or_disconnect​:

    • receive_or_disconnect​函数用于接收请求消息或断开连接。
    • 如果响应已经发送,则返回断开连接消息,即type​为http.disconnect​。
    • 使用anyio.create_task_group​创建任务组,并启动两个任务:等待响应发送和接收请求消息。
  5. 定义内部函数close_recv_stream_on_response_sent​:

    • 等待响应发送后关闭接收流。
  6. 定义内部函数send_no_error​:

    • 发送消息到发送流,如果发送流已经关闭,则忽略错误。
  7. 定义内部函数coro​:

    • coro​函数用于调用下一个中间件或应用程序,并捕获任何异常。
  8. 启动任务组

    • 启动两个任务:等待响应发送后关闭接收流和调用下一个中间件或应用程序。
  9. 处理响应消息

    • 接收响应消息并处理,如果没有响应返回,则抛出异常。
  10. 定义生成器body_stream​:

    • body_stream​生成器用于生成响应体的字节流。
  11. 创建响应对象

    • 创建一个_StreamingResponse​对象,并设置响应头和状态码。
  12. 调用调度函数

    • 创建任务组并调用调度函数,处理请求并发送响应。
  13. 设置响应发送事件

    • 调用response_sent.set()​,通知其他任务响应已经发送。