Juejin · Backend • 2024-05-13 11:33

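Hi everyone. In this post I want to share a highly productive tool for object-oriented programming in Python: the attrs library.

It may well be the best practice for object-oriented programming in Python.

Why attrs is needed

When developing and maintaining a large project, you may find that writing Python classes gets tedious.

We constantly have to add constructors, representation methods, comparison functions, and so on. Writing them by hand is a chore, and it is exactly the kind of work the language should handle transparently.

An earlier article covered the @dataclass decorator, introduced in Python 3.7 (PEP 557), which handles this whole family of problems elegantly. Today we look at another strong contender: attrs.

attrs unifies the way class attributes are described, which makes code more concise and readable.

Writing a complete class by hand means implementing many special methods. That leads to a lot of repetitive work and very verbose code.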

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# __author__:lianhaifeng
# __time__:2024/2/14 18:51
from typing import Any


class Coordinate:
    def __init__(self, x: Any, y: Any) -> None:
        self.x = x  # horizontal coordinate
        self.y = y  # vertical coordinate

    def __repr__(self) -> str:
        return f"Coordinate(x={self.x}, y={self.y})"  # string representation of the object

    def __eq__(self, other: Any) -> bool:
        # Check whether the other object is also a Coordinate
        if isinstance(other, self.__class__):
            # Compare the x and y coordinates of both objects
            return (self.x, self.y) == (other.x, other.y)
        else:
            return NotImplemented

    def __ne__(self, other: Any) -> bool:
        result = self.__eq__(other)
        if result is NotImplemented:
            return NotImplemented
        else:
            # Return the negation of the equality result
            return not result

    def __lt__(self, other: Any) -> bool:
        # Check whether the other object is also a Coordinate
        if isinstance(other, self.__class__):
            # Compare the ordering of the two objects' coordinates
            return (self.x, self.y) < (other.x, other.y)
        else:
            return NotImplemented

    def __le__(self, other: Any) -> bool:
        if isinstance(other, self.__class__):
            return (self.x, self.y) <= (other.x, other.y)
        else:
            return NotImplemented

    def __gt__(self, other: Any) -> bool:
        if isinstance(other, self.__class__):
            return (self.x, self.y) > (other.x, other.y)
        else:
            return NotImplemented

    def __ge__(self, other: Any) -> bool:
        if isinstance(other, self.__class__):
            return (self.x, self.y) >= (other.x, other.y)
        else:
            return NotImplemented

    def __hash__(self) -> int:
        # Return the hash value so the object can be stored in hashed data structures
        return hash((self.__class__, self.x, self.y))

Add some test code:

if __name__ == '__main__':
    # Create coordinate objects
    coord1 = Coordinate(3, 4)
    coord2 = Coordinate(5, 6)

    # Print the string representation of each object
    print(coord1)  # Output: Coordinate(x=3, y=4)
    print(coord2)  # Output: Coordinate(x=5, y=6)

    # Check whether the coordinates are equal
    print(coord1 == coord2)  # Output: False
    print(coord1 != coord2)  # Output: True

    # Compare the ordering of the coordinates
    print(coord1 < coord2)  # Output: True
    print(coord1 <= coord2)  # Output: True
    print(coord1 > coord2)  # Output: False
    print(coord1 >= coord2)  # Output: False

    # Use coordinate objects as dictionary keys
    coord_dict = {coord1: 'Point 1', coord2: 'Point 2'}
    print(coord_dict)  # Output: {Coordinate(x=3, y=4): 'Point 1', Coordinate(x=5, y=6): 'Point 2'}

    # Use a set to remove duplicate coordinate objects
    coords = [Coordinate(1, 2), Coordinate(3, 4), Coordinate(1, 2)]
    unique_coords = set(coords)
    print(unique_coords)  # Output (set order may vary): {Coordinate(x=1, y=2), Coordinate(x=3, y=4)}

You may quickly realize that this can be trimmed down as follows:

from typing import Any
from functools import total_ordering


@total_ordering
class Coordinate:
    def __init__(self, x: Any, y: Any) -> None:
        self.x = x  # horizontal coordinate
        self.y = y  # vertical coordinate

    def __repr__(self) -> str:
        return f"Coordinate(x={self.x}, y={self.y})"  # string representation of the object

    def __eq__(self, other: Any) -> bool:
        if isinstance(other, self.__class__):
            return (self.x, self.y) == (other.x, other.y)
        return NotImplemented

    def __lt__(self, other: Any) -> bool:
        if isinstance(other, self.__class__):
            return (self.x, self.y) < (other.x, other.y)
        return NotImplemented

    def __hash__(self) -> int:
        return hash((self.__class__, self.x, self.y))

At this point you might recall that the @dataclass decorator could be more elegant still, but today we implement it with the attrs library:

from typing import Any
import attr


@attr.s(hash=True)
class Coordinate:
    x: Any = attr.ib()  # horizontal coordinate
    y: Any = attr.ib()  # vertical coordinate

With just a short declaration, attrs generates everything the long version above implemented by hand, which is very handy.

Installation

pip install attrs cattrs

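A first look at attrs

The official introduction reads:

attrs is the Python package that will bring back the joy of writing classes by relieving you from the drudgery of implementing object protocols (aka dunder methods). Trusted by NASA for Mars missions since 2020!

Its main goal is to help you to write concise and correct software without slowing down your code.

The attrs repository currently has more than 5K stars on GitHub; it is very popular and actively maintained.

GitHub: https://github.com/python-attrs/attrs

attrs works by decorating a class with attrs.define (or attr.s) and then declaring the class's attributes with attrs.field, attr.ib, or plain type annotations.

Since version 21.3.0, attrs ships two top-level package names:

  • The classic attr, which powers the venerable attr.s and attr.ib.

  • The newer attrs, which contains only the most modern APIs and relies on attrs.define and attrs.field to define classes. It also provides some attr APIs with nicer defaults (such as attrs.asdict).

Importing from the newer attrs namespace is the recommended default.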

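In addition, the classic attr API defines several aliases, for example attr.s for attr.attrs and attr.ib for attr.attrib. Keep them in mind when writing your own code or reading other people's.

Introducing the attrs library

Basic usage and default values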

from loguru import logger
from attrs import define, field, Factory


@define
class Color(object):
    r = field(type=int, default=0)
    g = field(type=int, default=0)
    b = field(type=int, default=0)

    # Mutable objects as default values
    mutable = field(type=list, default=Factory(list))
    mutable2 = field(type=list, default=Factory(lambda: [1, 2, 3]))
    mutable3 = field(factory=list)


if __name__ == '__main__':
    color = Color(250, 255, 255, mutable=[1, 2, 3])
    logger.info(color)

    color2 = Color()
    logger.info(color2)

Output:

2024-02-18 19:42:08.012 | INFO     | __main__:<module>:23 - Color(r=250, g=255, b=255, mutable=[1, 2, 3], mutable2=[1, 2, 3], mutable3=[])
2024-02-18 19:42:08.012 | INFO     | __main__:<module>:26 - Color(r=0, g=0, b=0, mutable=[], mutable2=[1, 2, 3], mutable3=[])

Everything looks so clean. In a word: nice!

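Excluding an attribute from initialization

If a particular attribute should not be settable through the initializer, for example because it should always start from a fixed default value, set that attribute's init parameter to False.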

from attrs import define, field
from loguru import logger


@define
class Color(object):
    r = field(type=int, default=0, init=False)
    g = field(type=int, default=0)
    b = field(type=int, default=0)


if __name__ == '__main__':
    color = Color(255, 255)
    logger.info(color)

Output:

2024-02-18 22:46:44.228 | INFO     | __main__:<module>:19 - Color(r=0, g=255, b=255)

Derived attributes

How can an attribute depend on other attributes?

import datetime

import attr
from attrs import field
from loguru import logger


@attr.s
class Data(object):
    dt: datetime.datetime = attr.ib(factory=datetime.datetime.now)
    dt_30_later: datetime.datetime = field(init=False)
    dt_30_later2 = field(init=False)

    # Approach 1: compute the value in __attrs_post_init__
    def __attrs_post_init__(self):
        self.dt_30_later = self.dt + datetime.timedelta(days=30)

    # Approach 2: compute the value with a decorator-based default
    @dt_30_later2.default
    def _dt_30_later2(self):
        return self.dt + datetime.timedelta(days=30)


def serialize(inst, attribute, value):
    if isinstance(value, datetime.datetime):
        return value.isoformat()
    return value


json_data = attr.asdict(
    Data(datetime.datetime(2024, 2, 18, 21, 46)),
    value_serializer=serialize)

logger.info(json_data)

Output:

2024-02-18 21:50:52.516 | INFO     | __main__:<module>:36 - {'dt': '2024-02-18T21:46:00', 'dt_30_later': '2024-03-19T21:46:00', 'dt_30_later2': '2024-03-19T21:46:00'}

Keyword-only attributes

from attr import attrs, attrib
from loguru import logger


@attrs
class Color(object):
    r = attrib(type=int, default=0)
    g = attrib(type=int, default=0)
    b = attrib(type=int, default=0, kw_only=True)


if __name__ == '__main__':
    color = Color(250, 255, b=255)
    logger.info(color)

    color2 = Color()
    logger.info(color2)

Output:

2024-02-18 15:25:11.078 | INFO     | __main__:<module>:19 - Color(r=250, g=255, b=255)
2024-02-18 15:25:11.078 | INFO     | __main__:<module>:22 - Color(r=0, g=0, b=0)

Here b is declared keyword-only: if it is passed at all, it must be passed by name, otherwise an error is raised.

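Metadata

attrs also supports attaching metadata to attributes, which lets you carry extra information alongside each field and makes the code more self-documenting.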

import attrs


@attrs.define
class Product:
    name = attrs.field(metadata={"description": "Product name"})
    price = attrs.field(metadata={"description": "Product price"})


# Read an attribute's metadata
name_description = attrs.fields(Product).name.metadata["description"]
print(f"Name description: {name_description}")

Building immutable classes

Setting frozen=True produces immutable instances.

from attr import attrib, attrs
from loguru import logger


@attrs(frozen=True)
class Color(object):
    r = attrib(type=int, default=0)
    g = attrib(type=int, default=0)
    b = attrib(type=int, default=0, kw_only=True)


if __name__ == '__main__':
    color = Color(250, 255, b=255)
    logger.info(color)

    # color.r = 100  # frozen=True makes instances immutable after initialization; assigning raises FrozenInstanceError

Comparison

import attr


@attr.attrs
class Point:
    x = attr.ib()
    y = attr.ib()


point1 = Point(1, 2)
point2 = Point(3, 4)
point3 = Point(1, 2)

print(point1 == point2)  # Output: False
print(point1 != point2)  # Output: True
print(point1 == point3)  # Output: True

print(point1 < point2)  # Output: True
print(point1 <= point2)  # Output: True
print(point1 > point2)  # Output: False
print(point1 >= point2)  # Output: False

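Because the class is decorated with attr.attrs (the same decorator as attr.s), it already has __eq__, __ne__, __lt__, __le__, __gt__, and __ge__, so instances can be compared directly with the comparison operators. Note that the newer attrs.define generates the ordering methods only when order=True is passed.

When comparing objects, attrs internally converts each object's attributes to a tuple and compares the tuples. For example, Point(1, 2) < Point(3, 4) effectively compares the tuples (1, 2) and (3, 4). For the details of how tuples are compared, see the official documentation: https://docs.python.org/3/library/stdtypes.html#comparisons.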

Validating data

import attr
from loguru import logger


@attr.s
class Book:
    title = attr.ib(default="Unknown", validator=attr.validators.instance_of(str))
    pages = attr.ib(default=0, validator=attr.validators.instance_of(int))


# Create an object
book = Book(title="三国演义", pages=200)

# Log the object
logger.info(book)

book2 = Book(title="三国演义", pages="200")

Output:

2024-02-18 15:43:23.098 | INFO     | __main__:<module>:20 - Book(title='三国演义', pages=200)
Traceback (most recent call last):
  File "E:/projects/mukewang/python_and_go/约瑟夫.py", line 22, in <module>
    book2 = Book(title="三国演义", pages="200")
  File "<attrs generated init 3ea315b0de8c75a1cfeeeb1010a6df97094df790>", line 6, in __init__
  File "E:\ENV\py3.8_blog\lib\site-packages\attr\validators.py", line 22, in __call__
    raise TypeError(
TypeError: ("'pages' must be <class 'int'> (got '200' that is a <class 'str'>).", Attribute(name='pages', default=0, validator=<instance_of validator for type <class 'int'>>, repr=True, cmp=True, hash=None, init=True, metadata=mappingproxy({}), type=None, converter=None, kw_only=False), <class 'int'>, '200')

attrs can validate the data supplied at initialization.

import attr


@attr.s
class TrafficLight:
    color = attr.ib(validator=attr.validators.in_({"red", "yellow", "green"}), kw_only=True)


# Create an object
traffic_light = TrafficLight(color="red")

# No error is raised here
traffic_light.color = "blue"

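With @attr.s, assigning to an attribute after initialization does not raise an error.

With the newer attrs.define, validators also run when an attribute is assigned.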

import attrs


@attrs.define
class TrafficLight:
    color = attrs.field(default="red", validator=attrs.validators.in_({"red", "yellow", "green"}), kw_only=True)


# Create an object
traffic_light = TrafficLight()

# Assigning to the attribute runs the validator and raises an exception
try:
    traffic_light.color = "blue"
except ValueError as e:
    print(f"Validation Error: {e}")

Output:

Validation Error: ("'color' must be in {'red', 'green', 'yellow'} (got 'blue')", Attribute(name='color', default='red', validator=<in_ validator with options {'red', 'green', 'yellow'}>, repr=True, eq=True, eq_key=None, order=True, order_key=None, hash=None, init=True, metadata=mappingproxy({}), type=None, converter=None, kw_only=True, inherited=False, on_setattr=None, alias='color'), {'red', 'green', 'yellow'}, 'blue')

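With the classic attr API, validators normally run only when the object is initialized, not on attribute assignment, so assignment needs extra handling.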

import attr


def validate_color(instance, attribute, value):
    if value not in {"red", "yellow", "green"}:
        raise ValueError(f"Invalid color: {value}")


@attr.s
class TrafficLight:
    _color = attr.ib(default="red")

    @property
    def color(self):
        return self._color

    @color.setter
    def color(self, value):
        validate_color(self, self.color, value)
        self._color = value


# Create an object
traffic_light = TrafficLight()

# Assigning to the attribute runs the validator and raises an exception
try:
    traffic_light.color = "blue"
except ValueError as e:
    print(f"Validation Error: {e}")
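As an alternative to the property-based workaround, the classic API offers two built-in options that may be simpler: passing on_setattr=attr.setters.validate to attr.s, or calling attr.validate on an instance to re-run its validators on demand. The sketch below shows both; PlainLight is a hypothetical name used only for contrast.

```python
import attr

ALLOWED = {"red", "yellow", "green"}


@attr.s(on_setattr=attr.setters.validate)
class TrafficLight:
    color = attr.ib(default="red", validator=attr.validators.in_(ALLOWED))


@attr.s
class PlainLight:
    color = attr.ib(default="red", validator=attr.validators.in_(ALLOWED))


light = TrafficLight()
try:
    light.color = "blue"  # validator runs on assignment
    rejected = False
except ValueError:
    rejected = True
print(rejected, light.color)

plain = PlainLight()
plain.color = "blue"  # accepted silently by a plain @attr.s class
try:
    attr.validate(plain)  # re-runs every validator on the current values
    still_valid = True
except ValueError:
    still_valid = False
print(still_valid)
```

With on_setattr, the rejected assignment never reaches the instance, so the previous value is kept.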

Two ways to define a custom validator:

import attr
from loguru import logger


def pages_validator(instance, attribute, value):
    if value > 300:
        raise ValueError("pages must not be greater than 300")


@attr.s
class Book:
    title = attr.ib(default="Unknown", validator=attr.validators.instance_of(str))
    pages = attr.ib(default=0, validator=[attr.validators.instance_of(int), pages_validator])  # approach 1
    price = attr.ib(default=0.0, validator=attr.validators.instance_of(float))

    # Approach 2
    @title.validator
    def check_title(self, attribute, value):
        if value == "Unknown":
            raise ValueError("title cannot be 'Unknown'")


# Exercise the custom validators
try:
    book3 = Book(title="三国演义", pages=500)

    logger.info(book3)
except Exception as e:
    logger.error(e)

try:
    book3 = Book(pages=500)

    logger.info(book3)
except Exception as e:
    logger.error(e)

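Both calls fail: the first because pages is greater than 300, the second because title is left at its default of "Unknown".

A custom validator always takes three parameters:

  • instance (or self): the instance being validated
  • attribute: the Attribute object describing the attribute being validated
  • value: the value being assigned to the attribute

attrs passes these three arguments to the validator when the class is initialized, so the validator has everything it needs to make its decision.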

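There are other built-in validators as well, such as logical and/or combinators, callable checks, and iterable checks; see the official documentation: https://www.attrs.org/en/stable/api.html#validators.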

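Converters

Sometimes data arrives in a slightly non-standard form. For example, a value that should be the integer 100 is passed as the string "100". Raising an error right away is not very friendly in that situation, so we can configure converters to make the class more tolerant, for instance by converting strings to numbers automatically. Here is an example: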

import attr


def to_int(value):
    return int(value)


@attr.s
class Point:
    x = attr.ib(converter=to_int)
    y = attr.ib(converter=to_int)


# Usage example
point1 = Point(1, 2)
print(point1)

point2 = Point("3", "4")
print(point2)

point3 = Point("5", "hello")
print(point3)

Output:

Point(x=1, y=2)
Point(x=3, y=4)
Traceback (most recent call last):
  File "E:/projects/mukewang/python_and_go/约瑟夫.py", line 26, in <module>
    point3 = Point("5", "hello")
  File "<attrs generated init bc20f05ee98d8a0eaee1da898b5665773d9e94d7>", line 3, in __init__
  File "E:/projects/mukewang/python_and_go/约瑟夫.py", line 10, in to_int
    return int(value)
ValueError: invalid literal for int() with base 10: 'hello'

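attrs also lets you modify or transform a class's fields automatically while the class is being created, through the field_transformer hook. Its main use is to attach converters to attributes automatically based on their types.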

import datetime

import attrs
from attrs import frozen
from loguru import logger


def auto_convert(cls, fields):
    results = []
    for field in fields:
        logger.info(field)
        if field.converter is not None:
            results.append(field)
            continue
        if field.type in {datetime.datetime, 'datetime'}:
            converter = (lambda d: datetime.datetime.fromisoformat(d) if isinstance(d, str) else d)
        else:
            converter = None
        results.append(field.evolve(converter=converter))
    return results


@frozen(field_transformer=auto_convert)
class Data:
    a: int
    b: str
    c: datetime.datetime


logger.info("-----" * 5)
from_json = {"a": 3, "b": "spam", "c": "2020-05-04T13:37:00"}
logger.info(Data(**from_json))

logger.info(Data(a=3, b='spam', c=datetime.datetime(2020, 5, 4, 13, 37)))

logger.info(attrs.asdict(Data(a=3, b='spam', c=datetime.datetime(2020, 5, 4, 13, 37))))
logger.info(attrs.astuple(Data(c=datetime.datetime(2020, 5, 4, 13, 37), a=3, b='spam')))

Output:

2024-02-18 23:12:00.571 | INFO     | __main__:auto_convert:16 - Attribute(name='a', default=NOTHING, validator=None, repr=True, eq=True, eq_key=None, order=True, order_key=None, hash=None, init=True, metadata=mappingproxy({}), type=<class 'int'>, converter=None, kw_only=False, inherited=False, on_setattr=None, alias=None)
2024-02-18 23:12:00.571 | INFO     | __main__:auto_convert:16 - Attribute(name='b', default=NOTHING, validator=None, repr=True, eq=True, eq_key=None, order=True, order_key=None, hash=None, init=True, metadata=mappingproxy({}), type=<class 'str'>, converter=None, kw_only=False, inherited=False, on_setattr=None, alias=None)
2024-02-18 23:12:00.571 | INFO     | __main__:auto_convert:16 - Attribute(name='c', default=NOTHING, validator=None, repr=True, eq=True, eq_key=None, order=True, order_key=None, hash=None, init=True, metadata=mappingproxy({}), type=<class 'datetime.datetime'>, converter=None, kw_only=False, inherited=False, on_setattr=None, alias=None)
2024-02-18 23:12:00.571 | INFO     | __main__:<module>:35 - -------------------------
2024-02-18 23:12:00.571 | INFO     | __main__:<module>:37 - Data(a=3, b='spam', c=datetime.datetime(2020, 5, 4, 13, 37))
2024-02-18 23:12:00.571 | INFO     | __main__:<module>:39 - Data(a=3, b='spam', c=datetime.datetime(2020, 5, 4, 13, 37))
2024-02-18 23:12:00.571 | INFO     | __main__:<module>:41 - {'a': 3, 'b': 'spam', 'c': datetime.datetime(2020, 5, 4, 13, 37)}
2024-02-18 23:12:00.587 | INFO     | __main__:<module>:42 - (3, 'spam', datetime.datetime(2020, 5, 4, 13, 37))

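Serialization

In many situations we need to convert between objects and serialized forms such as JSON, especially when writing REST APIs or talking to a database.

attrs provides the asdict and astuple functions for serialization.

By default, asdict turns every attribute of an instance into a key-value pair of a dictionary. Its behavior can be customized through parameters: filter selects which attributes to include or exclude, and value_serializer supplies a custom function for serializing values.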

import attr
import attrs
from loguru import logger


@attrs.define
class UserInfo(object):
    users = attr.ib()


@attrs.define
class User(object):
    email = attr.ib()
    name = attr.ib()


# including only name and not email

json_data = attrs.asdict(UserInfo([User("[email protected]", "Lee"),
                                   User("[email protected]", "Rachel")]),
                         filter=lambda _attr, value: _attr.name != "email")

logger.info(json_data)


# ---

@attrs.define
class UserInfo(object):
    name = attrs.field()
    password = attrs.field()
    age = attrs.field()


# excluding attributes
logger.info(
    attrs.asdict(UserInfo("Marco", "abc@123", 22), filter=attrs.filters.exclude(attrs.fields(UserInfo).password, int)))


# ---
@attr.s
class Coordinates(object):
    x = attr.ib()
    y = attr.ib()
    z = attr.ib()


# including attributes

logger.info(attrs.asdict(Coordinates(20, "5", 3),
                         filter=attr.filters.include(int)))

logger.info(attrs.astuple(Coordinates(20, "5", 3),
                          filter=attr.filters.include(int)))

Output:

2024-02-18 23:25:21.529 | INFO     | __main__:<module>:27 - {'users': [{'name': 'Lee'}, {'name': 'Rachel'}]}
2024-02-18 23:25:21.529 | INFO     | __main__:<module>:40 - {'name': 'Marco'}
2024-02-18 23:25:21.529 | INFO     | __main__:<module>:55 - {'x': 20, 'z': 3}
2024-02-18 23:25:21.529 | INFO     | __main__:<module>:58 - (20, 3)

Serializing datetimes:

import datetime

import attr
from loguru import logger


@attr.s
class Data(object):
    dt: datetime.datetime = attr.ib(factory=datetime.datetime.now)


def serialize(inst, attribute, value):
    if isinstance(value, datetime.datetime):
        return value.isoformat()
    return value


json_data = attr.asdict(
    Data(datetime.datetime(2020, 5, 4, 13, 37)),
    value_serializer=serialize)

logger.info(json_data)

Output:

2024-02-18 21:14:20.875 | INFO     | __main__:<module>:26 - {'dt': '2020-05-04T13:37:00'}

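Although attrs can serialize instances on its own, the cattrs library is the more common choice for this job.

attrs focuses on creating Python classes and offers a few helpers for working with their instances, whereas cattrs concentrates on serializing and deserializing objects.

The examples below import cattrs under the name cattr, which is the library's older import name. It provides two main functions, structure and unstructure. The two are complementary and together give solid support for serializing and deserializing classes.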

from typing import Dict, Any
from attr import attrs, attrib
import cattr


@attrs
class Point:
    x: int = attrib(default=0)
    y: int = attrib(default=0)


def drop_non_attrs(d: Dict[str, Any], type_: type) -> Dict[str, Any]:
    if not isinstance(d, dict):
        return d
    attrs_attrs = getattr(type_, '__attrs_attrs__', None)
    if attrs_attrs is None:
        raise ValueError(f'type {type_} is not an attrs class')
    attrs_set = {attr.name for attr in attrs_attrs}
    return {key: val for key, val in d.items() if key in attrs_set}


def structure(d: Dict[str, Any], type_: type) -> Any:
    return cattr.structure(drop_non_attrs(d, type_), type_)


json_data = {'x': 1, 'y': 2, 'z': 3}
print(structure(json_data, Point))

Handling datetime values during conversion:

import datetime
from attr import attrs, attrib
import cattr
from loguru import logger

TIME_FORMAT = '%Y-%m-%dT%H:%M:%S.%fZ'


@attrs
class Event(object):
    happened_at = attrib(type=datetime.datetime)


cattr.register_unstructure_hook(datetime.datetime, lambda dt: dt.strftime(TIME_FORMAT))
cattr.register_structure_hook(datetime.datetime,
                              lambda string, _: datetime.datetime.strptime(string, TIME_FORMAT))

event = Event(happened_at=datetime.datetime(2024, 2, 18))
logger.info(f'event: {event}')
json = cattr.unstructure(event)
logger.info(f'json: {json}')
event = cattr.structure(json, Event)
logger.info(f'Event: {event}')

Here we register two hooks for the datetime type. When unstructuring, strftime converts the value to a string; when structuring, strptime converts the string back to a datetime.

Nested structures

from loguru import logger
from attr import attrs, attrib
from typing import List
from cattr import structure, unstructure


@attrs
class Point:
    x = attrib(type=int, default=0)
    y = attrib(type=int, default=0)


@attrs
class Color:
    r = attrib(default=0)
    g = attrib(default=0)
    b = attrib(default=0)


@attrs
class Line:
    color = attrib(type=Color)
    points = attrib(type=List[Point])


if __name__ == '__main__':
    try:
        line = Line(color=Color(), points=[Point(i, i) for i in range(5)])
        logger.info(f'Created Line object: {line}')

        json_data = unstructure(line)
        logger.info(f'Serialized JSON: {json_data}')

        line = structure(json_data, Line)
        logger.info(f'Deserialized Line object: {line}')
    except Exception as e:
        logger.error(f'An error occurred: {e}')

Output:

2024-02-18 22:41:46.668 | INFO     | __main__:<module>:35 - Created Line object: Line(color=Color(r=0, g=0, b=0), points=[Point(x=0, y=0), Point(x=1, y=1), Point(x=2, y=2), Point(x=3, y=3), Point(x=4, y=4)])
2024-02-18 22:41:46.668 | INFO     | __main__:<module>:38 - Serialized JSON: {'color': {'r': 0, 'g': 0, 'b': 0}, 'points': [{'x': 0, 'y': 0}, {'x': 1, 'y': 1}, {'x': 2, 'y': 2}, {'x': 3, 'y': 3}, {'x': 4, 'y': 4}]}
2024-02-18 22:41:46.668 | INFO     | __main__:<module>:41 - Deserialized Line object: Line(color=Color(r=0, g=0, b=0), points=[Point(x=0, y=0), Point(x=1, y=1), Point(x=2, y=2), Point(x=3, y=3), Point(x=4, y=4)])

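As you can see, the object is converted to a JSON-style representation with ease, and converted back to an object just as easily.

Example use cases

Below are a few practical usage examples:

Form validation in a web application: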

import attrs
from attrs import validators


@attrs.define
class UserRegistrationForm:
    username = attrs.field(validator=validators.instance_of(str))
    email = attrs.field(validator=validators.instance_of(str))
    password = attrs.field(validator=validators.instance_of(str))


# Usage example
form_data = {"username": "alice", "email": "[email protected]", "password": "123456"}
form = UserRegistrationForm(**form_data)

Data transfer objects (DTO):

import cattr

import attrs
from datetime import datetime

from loguru import logger


@attrs.define
class OrderDTO:
    order_id = attrs.field()
    customer_name = attrs.field()
    total_amount = attrs.field()

    # Timestamp attribute; factory=... calls datetime.now once per instance
    created_at = attrs.field(factory=datetime.now)


# Unstructure and structure hooks for datetime
def serialize_datetime(dt):
    return dt.strftime('%Y-%m-%d %H:%M:%S')


def deserialize_datetime(dt_str, _):
    return datetime.strptime(dt_str, '%Y-%m-%d %H:%M:%S')


# Create a converter and register the datetime hooks on it
converter = cattr.Converter()
converter.register_structure_hook(datetime, deserialize_datetime)
converter.register_unstructure_hook(datetime, serialize_datetime)
# Create an order DTO object
order_data = {"order_id": "123456", "customer_name": "Alice", "total_amount": 100}
order_dto = converter.structure(order_data, OrderDTO)

# Log the attributes of the order DTO
logger.info(f"Order ID: {order_dto.order_id}")
logger.info(f"Customer Name: {order_dto.customer_name}")
logger.info(f"Total Amount: {order_dto.total_amount}")
logger.info(f"Created At: {order_dto.created_at}")

# Convert the order DTO to a dictionary
# order_dict = {
#     "order_id": order_dto.order_id,
#     "customer_name": order_dto.customer_name,
#     "total_amount": order_dto.total_amount,
#     "created_at": order_dto.created_at.strftime("%Y-%m-%d %H:%M:%S")
# }

order_dict = converter.unstructure(order_dto)
logger.info(f"Order Dictionary: {order_dict}")

Output:

2024-02-18 17:30:04.563 | INFO     | __main__:<module>:27 - Order ID: 123456
2024-02-18 17:30:04.563 | INFO     | __main__:<module>:28 - Customer Name: Alice
2024-02-18 17:30:04.563 | INFO     | __main__:<module>:29 - Total Amount: 100
2024-02-18 17:30:04.563 | INFO     | __main__:<module>:30 - Created At: 2024-02-18 17:30:04.561213
2024-02-18 17:30:04.563 | INFO     | __main__:<module>:39 - Order Dictionary: {'order_id': '123456', 'customer_name': 'Alice', 'total_amount': 100, 'created_at': '2024-02-18 17:30:04'}

Book publishing example

from typing import List
import attrs
from loguru import logger


@attrs.define(auto_attribs=True)
class Book:
    title: str
    author: str
    pages: int
    price: float

    def is_expensive(self) -> bool:
        return self.price >= 20


@attrs.define(auto_attribs=True)
class BookCollection:
    books: List[Book] = attrs.field(factory=list)

    def add_book(self, book: Book):
        self.books.append(book)

    def remove_book(self, book: Book):
        self.books.remove(book)

    def get_expensive_books(self) -> List[Book]:
        return [book for book in self.books if book.is_expensive()]


# Create Book instances
book1 = Book(title="流畅的Python", author="拉马略", pages=523, price=24.99)
book2 = Book(title="Coding with Python", author="Another Author", pages=210, price=19.99)

# Create a BookCollection instance and add the books
collection = BookCollection()
collection.add_book(book1)
collection.add_book(book2)

# Remove a book
collection.remove_book(book2)

# Find all expensive books
expensive_books = collection.get_expensive_books()
logger.info(f"昂贵的图书: {expensive_books}")


@attrs.define
class User:
    name = attrs.field(type=str)
    age = attrs.field(converter=int)
    email = attrs.field(type=str)

    @age.validator
    def check_age(self, attribute, value):
        if value < 18:
            raise ValueError("User must be at least 18 years old")

    @email.validator
    def check_email(self, attribute, value):
        if "@" not in value:
            raise ValueError("Invalid email address")


# Create a User instance; it is deliberately invalid to demonstrate validation
try:
    user = User(name="John Doe", age=17, email="[email protected]")
except ValueError as e:
    logger.error(e)


@attrs.define(auto_attribs=True)
class Publisher:
    name: str
    founded: int
    location: str

    def publish(self, book: Book):
        logger.info(f"{self.name} 出版了:【{book.title}】")


@attrs.define(auto_attribs=True)
class Review:
    content: str
    book: Book
    score: int

    def is_positive(self) -> bool:
        return self.score > 3


@attrs.define(auto_attribs=True)
class AuthorProfile:
    name: str
    genre: str
    books_written: List[Book] = attrs.field(factory=list)

    def write_book(self, title: str, pages: int, price: float) -> Book:
        book = Book(title=title, author=self.name, pages=pages, price=price)
        self.books_written.append(book)
        return book


# Create a publisher instance
publisher = Publisher(name="人民邮电出版社", founded=2023, location="中国")

# The author writes a book and the publisher publishes it
author_profile = AuthorProfile(name="唐诗三百首", genre="Tech")
new_book = author_profile.write_book("流畅的Python", pages=300, price=29.99)
publisher.publish(new_book)

# Add some reviews
review1 = Review(content="不错的书!", book=new_book, score=8)
review2 = Review(content="烂书..", book=new_book, score=1)

# Suppose we want to show all positive reviews
positive_reviews = [review for review in [review1, review2] if review.is_positive()]

for review in positive_reviews:
    logger.info(f"- {review.content}")

Output:

2024-02-18 22:42:45.841 | INFO     | __main__:<module>:48 - 昂贵的图书: [Book(title='流畅的Python', author='拉马略', pages=523, price=24.99)]
2024-02-18 22:42:45.842 | ERROR    | __main__:<module>:72 - User must be at least 18 years old
2024-02-18 22:42:45.845 | INFO     | __main__:publish:82 - 人民邮电出版社 出版了:【流畅的Python】
2024-02-18 22:42:45.845 | INFO     | __main__:<module>:123 - - 不错的书!

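Summary

This article showed how to use the attrs and cattrs libraries for object-oriented programming in Python. With these two libraries, object-oriented programming in Python becomes much simpler.

On the surface, attrs may remind you of data classes (in fact, @dataclass is a descendant of attrs). In practice it does more and is more flexible. For example, it lets you define special handling for equality checks on NumPy arrays, offers more ways to hook into initialization, and lets you step through the generated methods in a debugger.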

It is strongly recommended to prefer the newer attrs APIs, such as attrs.define and attrs.field.

For more on using attrs, see the official documentation!

References

https://www.attrs.org/en/stable/init.html#hooking-yourself-into-initialization

https://catt.rs/en/stable/index.html#