FastAPI中自定义白名单黑名单中间件以及底层源码分析
定义白名单Whitelist中间件
from fastapi import FastAPI, Request
from fastapi.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
app = FastAPI()
WHITELIST = ["127.0.0.1", "172.0.0.1"]
class WhiteIPAddrMiddleware(BaseHTTPMiddleware):
def __init__(self, app: ASGIApp, whitelist: list[str]):
super().__init__(app)
self.whitelist = whitelist
await def dispatch(self, request: Request, call_next):
client_ip = request.client.host
if self.whitelist and client_ip not in self.whitelist:
return Response(status_code=403, content="Forbidden")
response = await call_next(request)
return response
app.add_middleware(WhiteIPAddrMiddleware, whitelist=WHITELIST)
@app.get("/")
async def root():
return {"msg": "Hello World!"}
定义黑名单Blacklist中间件
from fastapi import FastAPI, Request
from fastapi.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
app = FastAPI()
BLACKLIST = ["127.0.0.1", "172.0.0.1"]
class BLACKIPAddrMiddleware(BaseHTTPMiddleware):
def __init__(self, app: ASGIApp, blacklist: list[str]):
super().__init__(app)
self.blacklist= blacklist
await def dispatch(self, request: Request, call_next):
client_ip = request.client.host
if self.blacklist and client_ip in self.blacklist:
return Response(status_code=403, content="Forbidden")
response = await call_next(request)
return response
app.add_middleware(BlackIPAddrMiddleware, blacklist=BLACKLIST)
@app.get("/")
async def root():
return {"msg": "Hello World!"}
代码解析
为什么继承 BaseHTTPMiddleware
是必须的?
BaseHTTPMiddleware
是 Starlette 提供的一个类,用于简化中间件的创建。它封装了一些基础的方法和属性,比如 __call__
方法来处理请求。
并且对于每一个自定义的中间件,需要创建一个继承BaseHTTPMiddleware类,并覆盖 dispatch
方法。这种面向对象的实现方法,使得代码便于扩展和维护。
为什么super().init(app)是必须的?
super().__init__(app)
在这里是必须的,因为它确保了父类 BaseHTTPMiddleware
的初始化方法被正确调用,这样才能正确设置中间件的基础属性和方法。让我们逐步详细解析为什么以及是如何工作的。
1. 类继承和初始化
在 Python 中,类继承允许一个类(子类)继承另一个类(父类)的属性和方法。子类可以重载父类的方法以实现特定功能,但仍然需要调用父类的方法来保留和初始化父类的属性和功能。
2. BaseHTTPMiddleware
的作用
BaseHTTPMiddleware
是一个抽象基类,它封装了一些基础的中间件功能。其主要任务是接收请求、处理中间件逻辑,然后将请求传递给下一个中间件或实际的请求处理代码。如果你的中间件类有多层继承,super().__init__
的调用可以确保继承链中的每一个类的构造函数都能得到执行。
3. 参数app的传递
BaseHTTPMiddleware
的初始化主要包括设置一些基本的属性,比如应用实例 app
,传递app
为参数,使得中间件知道它是属于哪个应用的。
为什么dispatch方法是必须的?
-
抽象基类模式:
-
dispatch
方法是一个抽象方法,意义在于强制要求所有子类必须实现这个方法。这确保了所有继承BaseHTTPMiddleware
的类都有一个用于处理请求的机制。 -
raise NotImplementedError("dispatch must be implemented by child class")
表明每个具体的中间件类必须自行实现dispatch
,否则会报错。
-
-
请求处理的核心:
-
dispatch
方法是指定在请求进入你的应用前以及处理后被调用的逻辑。在其中,你可以对请求进行检查、修改、记录日志等操作。 - 子类通过实现
dispatch
方法来定义中间件的具体行为。
-
-
实现自定义逻辑:
- 通过编写
dispatch
方法,开发者可以插入自定义逻辑来处理请求。比如IP白名单过滤、黑名单过滤、权限检查等。 - 这是实现你自定义中间件功能的主要地方。
- 通过编写
简化版本的 BaseHTTPMiddleware
可能看起来像这样:
class BaseHTTPMiddleware:
def __init__(self, app):
self.app = app
async def __call__(self, scope, receive, send):
assert scope['type'] == 'http'
response = await self.dispatch(scope, receive, send)
if response is None:
await self.app(scope, receive, send)
else:
await response(scope, receive, send)
async def dispatch(self, scope, receive, send):
raise NotImplementedError("dispatch must be implemented by child class")
-
__init__
方法: 初始化中间件,设置中间件所属的应用实例self.app
。 -
__call__
方法: 这是 ASGI 应用的入口点,用于处理传入的请求。这部分代码会先调用dispatch
方法进行请求处理。 -
dispatch
方法: 必须由子类实现的抽象方法,用于定义中间件具体的请求处理逻辑。
BaseHTTPMiddleware源码分析
# starlette/middleware/base.py
import typing
import anyio
from starlette.background import BackgroundTask
from starlette.requests import Request
from starlette.responses import ContentStream, Response, StreamingResponse
from starlette.types import ASGIApp, Message, Receive, Scope, Send
RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
DispatchFunction = typing.Callable[
[Request, RequestResponseEndpoint], typing.Awaitable[Response]
]
T = typing.TypeVar("T")
class BaseHTTPMiddleware:
def __init__(
self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None
) -> None:
self.app = app
self.dispatch_func = self.dispatch if dispatch is None else dispatch
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return
response_sent = anyio.Event()
async def call_next(request: Request) -> Response:
app_exc: typing.Optional[Exception] = None
send_stream, recv_stream = anyio.create_memory_object_stream()
async def receive_or_disconnect() -> Message:
if response_sent.is_set():
return {"type": "http.disconnect"}
async with anyio.create_task_group() as task_group:
async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
result = await func()
task_group.cancel_scope.cancel()
return result
task_group.start_soon(wrap, response_sent.wait)
message = await wrap(request.receive)
if response_sent.is_set():
return {"type": "http.disconnect"}
return message
async def close_recv_stream_on_response_sent() -> None:
await response_sent.wait()
recv_stream.close()
async def send_no_error(message: Message) -> None:
try:
await send_stream.send(message)
except anyio.BrokenResourceError:
# recv_stream has been closed, i.e. response_sent has been set.
return
async def coro() -> None:
nonlocal app_exc
async with send_stream:
try:
await self.app(scope, receive_or_disconnect, send_no_error)
except Exception as exc:
app_exc = exc
task_group.start_soon(close_recv_stream_on_response_sent)
task_group.start_soon(coro)
try:
message = await recv_stream.receive()
info = message.get("info", None)
if message["type"] == "http.response.debug" and info is not None:
message = await recv_stream.receive()
except anyio.EndOfStream:
if app_exc is not None:
raise app_exc
raise RuntimeError("No response returned.")
assert message["type"] == "http.response.start"
async def body_stream() -> typing.AsyncGenerator[bytes, None]:
async with recv_stream:
async for message in recv_stream:
assert message["type"] == "http.response.body"
body = message.get("body", b"")
if body:
yield body
if app_exc is not None:
raise app_exc
response = _StreamingResponse(
status_code=message["status"], content=body_stream(), info=info
)
response.raw_headers = message["headers"]
return response
async with anyio.create_task_group() as task_group:
request = Request(scope, receive=receive)
response = await self.dispatch_func(request, call_next)
await response(scope, receive, send)
response_sent.set()
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
raise NotImplementedError() # pragma: no cover
这段代码实现了一个基于Starlette框架的HTTP中间件类BaseHTTPMiddleware
。它的主要功能是拦截HTTP请求并在请求处理前后执行一些自定义逻辑。以下是对代码的详细分析:
导入模块
import typing
import anyio
from starlette.background import BackgroundTask
from starlette.requests import Request
from starlette.responses import ContentStream, Response, StreamingResponse
from starlette.types import ASGIApp, Message, Receive, Scope, Send
-
typing
:用于类型提示。 -
anyio
:一个异步I/O库,提供了对异步任务和流的支持。anyio是什么 -
starlette
相关模块:用于处理HTTP请求和响应。
类型定义
RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
DispatchFunction = typing.Callable[
[Request, RequestResponseEndpoint], typing.Awaitable[Response]
]
T = typing.TypeVar("T")
-
RequestResponseEndpoint
:一个函数类型,接收一个Request
对象并返回一个Response
对象。 -
DispatchFunction
:一个函数类型,接收一个Request
对象和一个RequestResponseEndpoint
,返回一个Response
对象。 -
T
:一个泛型类型变量。
BaseHTTPMiddleware
类
初始化方法
class BaseHTTPMiddleware:
def __init__(
self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None
) -> None:
self.app = app
self.dispatch_func = self.dispatch if dispatch is None else dispatch
-
app
:ASGI应用程序实例。 -
dispatch
:可选的调度函数,如果未提供,则使用类的dispatch
方法。
__call__
方法
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return
response_sent = anyio.Event()
- 检查请求类型是否为HTTP,如果不是,则直接调用下一个中间件或应用程序。
- 创建一个
anyio.Event
对象,用于在响应发送后通知其他任务。
call_next
函数
async def call_next(request: Request) -> Response:
app_exc: typing.Optional[Exception] = None
send_stream, recv_stream = anyio.create_memory_object_stream()
-
call_next
函数用于调用下一个中间件或应用程序,并处理请求和响应。 - 创建两个内存对象流
send_stream
和recv_stream
,用于在任务之间传递消息。
receive_or_disconnect
函数
async def receive_or_disconnect() -> Message:
if response_sent.is_set():
return {"type": "http.disconnect"}
async with anyio.create_task_group() as task_group:
async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
result = await func()
task_group.cancel_scope.cancel()
return result
task_group.start_soon(wrap, response_sent.wait)
message = await wrap(request.receive)
if response_sent.is_set():
return {"type": "http.disconnect"}
return message
-
receive_or_disconnect
函数用于接收请求消息或断开连接。 - 如果响应已经发送,则返回断开连接消息。
- 使用
anyio.create_task_group
创建任务组,并启动两个任务:等待响应发送和接收请求消息。
close_recv_stream_on_response_sent
函数
async def close_recv_stream_on_response_sent() -> None:
await response_sent.wait()
recv_stream.close()
- 等待响应发送后关闭接收流。
send_no_error
函数
async def send_no_error(message: Message) -> None:
try:
await send_stream.send(message)
except anyio.BrokenResourceError:
return
- 发送消息到发送流,如果发送流已经关闭,则忽略错误。
coro
函数
async def coro() -> None:
nonlocal app_exc
async with send_stream:
try:
await self.app(scope, receive_or_disconnect, send_no_error)
except Exception as exc:
app_exc = exc
-
coro
函数用于调用下一个中间件或应用程序,并捕获任何异常。
启动任务组
task_group.start_soon(close_recv_stream_on_response_sent)
task_group.start_soon(coro)
- 启动两个任务:等待响应发送后关闭接收流和调用下一个中间件或应用程序。
处理响应消息
try:
message = await recv_stream.receive()
info = message.get("info", None)
if message["type"] == "http.response.debug" and info is not None:
message = await recv_stream.receive()
except anyio.EndOfStream:
if app_exc is not None:
raise app_exc
raise RuntimeError("No response returned.")
assert message["type"] == "http.response.start"
- 接收响应消息并处理,如果没有响应返回,则抛出异常。
body_stream
生成器
async def body_stream() -> typing.AsyncGenerator[bytes, None]:
async with recv_stream:
async for message in recv_stream:
assert message["type"] == "http.response.body"
body = message.get("body", b"")
if body:
yield body
if app_exc is not None:
raise app_exc
-
body_stream
生成器用于生成响应体的字节流。
创建响应对象
response = _StreamingResponse(
status_code=message["status"], content=body_stream(), info=info
)
response.raw_headers = message["headers"]
return response
- 创建一个
_StreamingResponse
对象,并设置响应头和状态码。
调用调度函数
async with anyio.create_task_group() as task_group:
request = Request(scope, receive=receive)
response = await self.dispatch_func(request, call_next)
await response(scope, receive, send)
response_sent.set()
- 创建任务组并调用调度函数,处理请求并发送响应。
dispatch
方法
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
raise NotImplementedError() # pragma: no cover
-
dispatch
方法是一个抽象方法,子类需要实现该方法以定义自定义的请求处理逻辑。
执行流程
这段代码的执行流程可以分为以下几个步骤:
-
初始化中间件:
-
BaseHTTPMiddleware
类的实例化时,__init__
方法被调用,初始化app
和dispatch_func
属性。
-
-
调用中间件:
- 当一个HTTP请求到达时,
__call__
方法被调用。 - 如果请求类型不是HTTP(例如WebSocket),则直接调用下一个中间件或应用程序。
- 创建一个
anyio.Event
对象response_sent
,用于在响应发送后通知其他任务。
- 当一个HTTP请求到达时,
-
定义内部函数
call_next
:-
call_next
函数用于调用下一个中间件或应用程序,并处理请求和响应。 - 创建两个内存对象流
send_stream
和recv_stream
,用于在任务之间传递消息。
-
-
定义内部函数
receive_or_disconnect
:-
receive_or_disconnect
函数用于接收请求消息或断开连接。 - 如果响应已经发送,则返回断开连接消息,即
type
为http.disconnect
。 - 使用
anyio.create_task_group
创建任务组,并启动两个任务:等待响应发送和接收请求消息。
-
-
定义内部函数
close_recv_stream_on_response_sent
:- 等待响应发送后关闭接收流。
-
定义内部函数
send_no_error
:- 发送消息到发送流,如果发送流已经关闭,则忽略错误。
-
定义内部函数
coro
:-
coro
函数用于调用下一个中间件或应用程序,并捕获任何异常。
-
-
启动任务组:
- 启动两个任务:等待响应发送后关闭接收流和调用下一个中间件或应用程序。
-
处理响应消息:
- 接收响应消息并处理,如果没有响应返回,则抛出异常。
-
定义生成器
body_stream
:-
body_stream
生成器用于生成响应体的字节流。
-
-
创建响应对象:
- 创建一个
_StreamingResponse
对象,并设置响应头和状态码。
- 创建一个
-
调用调度函数:
- 创建任务组并调用调度函数,处理请求并发送响应。
-
设置响应发送事件:
- 调用
response_sent.set()
,通知其他任务响应已经发送。
- 调用