掘金 阅读 ( ) • 2024-05-05 11:06

theme: smartblue

一、前言

最近在用 Rust 刷算法题的过程中遇到了几道二叉树相关的问题,题目给的输入基本都是二叉树层序遍历的结果(如下图所示),在本地测试的话需要根据其自行构造二叉树。

Clip_2024-05-04_12-24-22.png

这部分的具体代码实现很常见,因此不再详细介绍。最后的形式如下面代码所示,其中传入参数的类型为Vec<Option<T>>,即二叉树层序遍历的结果。

TreeNode::from(vec![Some(1), Some(5), Some(3), None, Some(4), Some(10), Some(6), Some(9), Some(2)])

可以看到参数的形式比输入示例复杂了不少,每次在本地测试都需要手动输入测试用例,略显麻烦。本着省事的原则(懒是第一推动力),就想通过某些方法实现可以直接根据测试用例的输入构造二叉树。

但是测试用例给的输入是 JSON 格式的数据,Rust 本身不支持类似的语法,虽然可以通过 serde 等库来对 JSON 进行反序列化,但是其难免会引入运行时开销。而宏编程可以在不引入运行时开销的前提下完成对自定义语法的扩展,完美符合要求。

// 测试用例的输入
// [1, 5, 3, null, 4, 10, 6, 9, 2]

// 使用 serde 进行 JSON 反序列化
let vals: Vec<Option<i32>> = serde_json::from_str("[1, 5, 3, null, 4, 10, 6, 9, 2]").unwrap();
let root = TreeNode::from(vals);

// 使用类函数宏
let root = tree![1, 5, 3, null, 4, 10, 6, 9, 2];

二、Rust 中的宏编程

参考 Rust 圣经中关于宏的介绍:宏 - Rust 程序设计语言 简体中文版 (kaisery.github.io)

正如上面所说,宏编程可以在不引入运行时开销的同时创建自定义语法扩展,其本质上是一种在编译阶段通过代码生成代码的方式,能帮助简化代码和减少大量重复的代码,提升代码可读性和可维护性。

在 Rust 中,宏有声明宏(Declarative Macros)和过程宏(Procedural Macros)两种类型,其中声明宏通过模式匹配进行定义,而过程宏通过函数进行定义。而过程宏又包含派生宏(Derive Macros)、类属性宏(Attribute-like Macros)和类函数宏(Function-like Macros)三种。

1. 声明宏

声明宏是 Rust 中最常用的宏类型,比如常见的println!vec!都属于声明宏。声明宏通过macro_rules!定义,其接收一个表达式,通过对表达式的结果进行模式匹配来执行对应代码。一个vec!宏定义的简化版本如下面代码所示:

#[macro_export]
macro_rules! vec {
    ( $( $x:expr ),* ) => {
        {
            let mut temp_vec = Vec::new();
            $(
                temp_vec.push($x);
            )*
            temp_vec
        }
    };
}

2. 派生宏和类属性宏

派生宏通常用于为结构体生成 trait 的默认实现,比如常见的#[derive(Debug)]等。在下面这个简单的例子中,用户可以使用#[derive(Hello)]注解它们的类型来得到hello函数的默认实现。而对于派生宏的定义,需要实现一个函数,将结构体代码作为参数输入,其类型为TokenStream,并将生成的扩展代码作为结果返回,其类型同样为TokenStream

use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, DeriveInput};

#[proc_macro_derive(Hello)]
pub fn hello_macro_derive(input: TokenStream) -> TokenStream {
    let ast = parse_macro_input!(input as DeriveInput); // 将输入 TokenStream 转换为抽象语法树
    let ident = ast.ident; // 获取结构体的名称
    let output = quote! {
        impl Hello for #ident {
            pub fn hello() {
                println!("Hello world!");
            }
        }
    }; // 使用 quote! 宏扩展代码
    output.into()
}
trait Hello {
    fn hello();
}

#[derive(Hello)]
struct Pancakes;

// impl Hello for Pancakes {
//     pub fn hello() {
//         println!("Hello world!");
//     }
// }

fn main() {
    Pancakes::hello();
}

类属性宏与派生宏类似,不同的是派生宏只能用于结构体和枚举,而类属性宏还可以用于其它的项,比如函数。类属性宏的定义如下面代码所示:

#[proc_macro_attribute]
pub fn attribute_like_macro(attr: TokenStream, item: TokenStream) -> TokenStream {
    ...
}
#[attribute_like_macro(...)]
fn test() {
    ...
}

3. 类函数宏

类函数宏和声明宏一样在使用上都与函数调用类似,但由于使用函数定义的方法,其比声明宏更为灵活。类函数宏的定义如下面代码所示,获取括号中的代码,并返回希望生成的代码

#[proc_macro]
pub fn func_like_macro(input: TokenStream) -> TokenStream {
    ...
}
func_like_macro!(...);

三、自定义类函数宏

正如上述关于 Rust 中不同类型的宏的介绍,使用类函数宏来解决前言中提到的问题是最为合适的,因此下面介绍如何实现这样一个自定义的类函数宏。

1. 语法分析

抽象语法树(Abstract Syntax Tree, AST)是源代码语法结构的一种抽象表示,以树状数据结构表示编程语言的语法结构。虽然宏编程可以实现自定义语法扩展,但最终还是需要解析成 AST,因此下面先进行语法分析。

这里推荐一个网站:AST explorer,可以将输入的代码解析成 AST,可以方便我们进行语法分析。

Clip_2024-05-04_20-50-25.png

Clip_2024-05-04_21-18-52.png

这里重点看上图中标记的TokenStream部分,这便是类函数宏定义的输入,可以看到列表中的内容被解析成了三种不同类型的 token:Literal(字面量)、 Ident(标识符)和 Punct(标点)。

2. 代码实现

通过语法分析,结合宏扩展的目标,我们发现只需要在遍历TokenStream的过程中对字面量和除null之外的标识符使用Some进行包裹,而将null转换为None。具体代码如下:

#[proc_macro]
pub fn tree(input: TokenStream) -> TokenStream {
    let mut vals = Vec::new();
    let mut is_comma = false;
    for token in input {
        if let TokenTree::Punct(_) = &token {
            if is_comma {
                vals.push(String::from("None")); // 兼容连续出现两个逗号的情况
            }
        } else {
            let token_str = token.to_string();
            let val = if token_str == "None" || token_str == "null" {
                String::from("None")
            } else {
                format!("Some({token})")
            };
            vals.push(val);
            is_comma = false;
        }
    }
    let output = format!("TreeNode::from(vec![{}])", vals.join(", ")); // 拼接生成的代码
    output.parse().unwrap()
}