Juejin · Backend • 2024-04-01 23:41

A Complete Task Lifecycle

At this point we can finally answer the second question we left open at the end of the article on Celery's basic architecture:

How are the tasks we define scanned and registered into celery?

As we know, Python has nothing quite like Java's annotations, nor any annotation-scanning mechanism. So how does celery know which of our functions have been decorated with @task, and how does it add them to its context?

Before digging into the source code, we need to get one concept straight. Recall that in an earlier article we analyzed the task protocol: metadata such as the task_id goes in the message headers, while the arguments go in the body. This reveals something quite interesting: celery does not ship the function itself as the unit of work over the message queue; it only tells the worker the task's name and its arguments. Why not send the whole function over as a unit of computation, so the worker could simply execute whatever code it receives?

The benefit of celery's approach is obvious: it drastically shrinks the message size, since the serialized form of an entire function is usually not small. That saves data-transfer cost and improves efficiency.

The drawback is just as obvious: the worker becomes deeply coupled to our code. The worker must be started from within our project, otherwise it has no way to use the task_name to locate and execute the corresponding task logic.
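To make this concrete, here is a minimal sketch of the idea. The registry dict and the JSON shape below are illustrative, not celery's actual wire format: only the name and arguments cross the broker, and a worker that has imported the same code resolves the name back to a function.

```python
import json

TASKS = {}  # hypothetical registry: task name -> function

def task(fun):
    """Register a function under its dotted name, like @task does."""
    TASKS[f"{fun.__module__}.{fun.__name__}"] = fun
    return fun

@task
def add(x, y):
    return x + y

# The broker message carries only the task name and arguments,
# never the function's code:
message = json.dumps({"task": f"{add.__module__}.add", "args": [2, 3]})

# A worker that imported the same module resolves the name and runs it:
decoded = json.loads(message)
result = TASKS[decoded["task"]](*decoded["args"])
print(result)  # 5
```

If the worker has not imported the module, the lookup fails; that is exactly the coupling described above.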

Since the usual way to declare a celery task is to slap the @task decorator onto a function, that is exactly where we will start digging.

The @task decorator

def task(*args, **kwargs):
    """Deprecated decorator, please use :func:`celery.task`."""
    return current_app.task(*args, **dict({'base': Task}, **kwargs))

Click in and it turns out to be very simple. Note that current_app here is the instance of the global Celery object in our worker. What this decorator actually does is wrap our function into a Task proxy object; after all, where else would your plain function get its delay and apply_async methods? And the task method called here is the task method of the Celery class that we said earlier is extremely important. Pasting the code once more; the path is celery.app.base.Celery:

def task(self, *args, **opts):

    if USING_EXECV and opts.get('lazy', True):
        # When using execv the task in the original module will point to a
        # different app, so doing things like 'add.request' will point to
        # a different task instance.  This makes sure it will always use
        # the task instance from the current app.
        # Really need a better solution for this :(
        from . import shared_task
        return shared_task(*args, lazy=False, **opts)

    def inner_create_task_cls(shared=True, filter=None, lazy=True, **opts):
        _filt = filter

        def _create_task_cls(fun):
            if shared:
                def cons(app):
                    return app._task_from_fun(fun, **opts)
                cons.__name__ = fun.__name__
                connect_on_app_finalize(cons)
            if not lazy or self.finalized:
                ret = self._task_from_fun(fun, **opts)
            else:
                # return a proxy object that evaluates on first use
                ret = PromiseProxy(self._task_from_fun, (fun,), opts,
                                    __doc__=fun.__doc__)
                # append to the pending deque; note that this branch and
                # connect_on_app_finalize(cons) above each run once
                self._pending.append(ret)
            if _filt:
                return _filt(ret)
            return ret

        return _create_task_cls

    if len(args) == 1:
        if callable(args[0]):
            return inner_create_task_cls(**opts)(*args)
        raise TypeError('argument 1 to @task() must be a callable')
    if args:
        raise TypeError(
            '@task() takes exactly 1 argument ({0} given)'.format(
                sum([len(args), len(opts)])))
    return inner_create_task_cls(**opts)

Note inner_create_task_cls: this is the thing our task ultimately gets wrapped by. There are two spots to keep an eye on: connect_on_app_finalize, and the PromiseProxy/_pending branch. Since shared defaults to True, the connect_on_app_finalize line always executes. Click in:

_on_app_finalizers = set()

def connect_on_app_finalize(callback):
    """Connect callback to be called when any app is finalized."""
    _on_app_finalizers.add(callback)
    return callback

Here we stumble on a little secret: celery internally maintains a module-level set. We don't yet know what this set is for, but we do know that a function was just added to it: cons, which in turn calls _task_from_fun. Click into that:

def _task_from_fun(self, fun, name=None, base=None, bind=False, **options):
    if not self.finalized and not self.autofinalize:
        raise RuntimeError('Contract breach: app not finalized')
    name = name or self.gen_task_name(fun.__name__, fun.__module__)
    base = base or self.Task
    if name not in self._tasks:
        run = fun if bind else staticmethod(fun)
        task = type(fun.__name__, (base,), dict({
            'app': self,
            'name': name,
            'run': run,
            '_decorated': True,
            '__doc__': fun.__doc__,
            '__module__': fun.__module__,
            '__header__': staticmethod(head_from_fun(fun, bound=bind)),
            '__wrapped__': run}, **options))()
        # for some reason __qualname__ cannot be set in type()
        # so we have to set it here.
        try:
            task.__qualname__ = fun.__qualname__
        except AttributeError:
            pass
        # register into the app-level task mapping
        self._tasks[task.name] = task
        # this binds the task to the app, copying app-level state onto it
        task.bind(self)  # connects task to this app
        """
        ...... a less important stretch omitted
        """
        if autoretry_for and not hasattr(task, '_orig_run'):

            @wraps(task.run)
            def run(*args, **kwargs):
                try:
                    return task._orig_run(*args, **kwargs)
                except Ignore:
                    # If Ignore signal occures task shouldn't be retried,
                    # even if it suits autoretry_for list
                    raise
                except Retry:
                    raise
                except autoretry_for as exc:
                    if retry_backoff:
                        retry_kwargs['countdown'] = \
                            get_exponential_backoff_interval(
                                factor=retry_backoff,
                                retries=task.request.retries,
                                maximum=retry_backoff_max,
                                full_jitter=retry_jitter)
                    raise task.retry(exc=exc, **retry_kwargs)

            task._orig_run, task.run = task.run, run
    else:
        task = self._tasks[name]
    return task

The readability of this code has truly transcended mortal limits. name is easy enough: if you specified one it is used, otherwise one is generated for you, and the generated name looks like itsm.ticket.tasks.notify_task, i.e. the module path plus the function name. fun here is simply our function object.
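A simplified sketch of that naming rule (the real gen_task_name helper also special-cases things like __main__; the function below is only an illustration):

```python
# Simplified default naming rule: "<module path>.<function name>"
def gen_task_name(name, module):
    return f"{module}.{name}"

print(gen_task_name("notify_task", "itsm.ticket.tasks"))
# itsm.ticket.tasks.notify_task
```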

We can see that a task object is first created with the type() built-in; this task object is an instance of a dynamically created subclass of Task. It is then added to the app's _tasks mapping, and bound to the app. Pay special attention to the self._tasks attribute; its contents look roughly like this:

{
  "itsm.ticket.tasks.notify_task": <@task: itsm.ticket.tasks.notify_task of proj at 0x7fbcecd3d9b0 (v2 compatible)>
}

And with that, the mapping from task_name to its executable object is in place.

Careful, though: what _on_app_finalizers stores is not the result of calling anything; we passed in the function cons itself, which has not executed yet. So where does it actually run? For that we have to go back to the Worker's initialization. As mentioned before, the Worker calls on_before_init during initialization, and inside that function we find this line: trace.setup_worker_optimizations(self.app, self.hostname)

class Worker(WorkController):
    """Worker as a program."""

    def on_before_init(self, quiet=False, **kwargs):
        self.quiet = quiet
        # this is the line
        trace.setup_worker_optimizations(self.app, self.hostname)

        # this signal can be used to set up configuration for
        # workers by name.
        signals.celeryd_init.send(
            sender=self.hostname, instance=self,
            conf=self.app.conf, options=kwargs,
        )
        check_privileges(self.app.conf.accept_content)

When we rush to the scene of setup_worker_optimizations, all we find is this:

# evaluate all task classes by finalizing the app.
app.finalize()

So, back to the finalize method of the Celery class:

def finalize(self, auto=False):
    """Finalize the app.

    This loads built-in tasks, evaluates pending task decorators,
    reads configuration, etc.
    """
    with self._finalize_mutex:
        if not self.finalized:
            if auto and not self.autofinalize:
                raise RuntimeError('Contract breach: app not finalized')
            self.finalized = True
            _announce_app_finalized(self)

            pending = self._pending
            while pending:
                maybe_evaluate(pending.popleft())

            for task in values(self._tasks):
                task.bind(self)

            self.on_after_finalize.send(sender=self)

One line stands out; click into _announce_app_finalized(self):

def _announce_app_finalized(app):
    callbacks = set(_on_app_finalizers)
    for callback in callbacks:
        callback(app)

Once this executes we find that, hey, our self._tasks is finally complete. This design is quite convoluted; it may take a few readings to untangle the call relationships.
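The whole deferred-registration dance condenses into a short sketch. The names below mirror the source, but this is an illustration, not celery's actual code:

```python
_on_app_finalizers = set()

def connect_on_app_finalize(callback):
    _on_app_finalizers.add(callback)
    return callback

class App:
    def __init__(self):
        self._tasks = {}

    def _task_from_fun(self, fun):
        self._tasks[fun.__name__] = fun

    def finalize(self):
        # only now do the stored `cons`-style closures actually run
        for callback in _on_app_finalizers:
            callback(self)

def my_task():
    return "done"

# what @task arranges at import time: store a closure, execute nothing
connect_on_app_finalize(lambda app: app._task_from_fun(my_task))

app = App()
print(app._tasks)   # {} -- nothing registered yet
app.finalize()
print(app._tasks)   # {'my_task': <function my_task at ...>}
```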

At this point we have essentially answered the question raised above: how a Task gets registered into the global context. When we import a module, its decorators execute immediately, so when celery dynamically loads every module under the project, each task registers itself into celery.
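That "decorators run at import time" point is easy to verify. In this sketch, exec() stands in for importing a tasks module; the registry and the module source are made up for the demo:

```python
registry = []

def task(fun):
    registry.append(fun.__name__)   # side effect at definition time
    return fun

module_source = """
@task
def cleanup():
    pass
"""

print(registry)        # [] -- nothing yet
exec(module_source)    # simulates importing a tasks module
print(registry)        # ['cleanup']
```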

Task Consumption

Having covered how tasks are discovered, we next need to work out how they are consumed. Celery maintains a multi-process consumption model, but the details of task execution are beyond the scope of this article and may be covered in a later chapter.

Stop and think for a moment: if task registration starts from @task, then where is the entry point of task consumption?

It has to start the moment a message is received, because only after a task has been received can the consumption process begin. Remember that useful function we found in the Consumer, on_task_received? It is bound as kombu's message callback, meaning that whenever a new message shows up on the queue, on_task_received is invoked.

def create_task_handler(self, promise=promise):
    strategies = self.strategies
    on_unknown_message = self.on_unknown_message
    on_unknown_task = self.on_unknown_task
    on_invalid_task = self.on_invalid_task
    callbacks = self.on_task_message
    call_soon = self.call_soon

    def on_task_received(message):
        # payload will only be set for v1 protocol, since v2
        # will defer deserializing the message body to the pool.
        payload = None
        try:
            type_ = message.headers['task']                # protocol v2
        except TypeError:
            return on_unknown_message(None, message)
        except KeyError:
            try:
                payload = message.decode()
            except Exception as exc:  # pylint: disable=broad-except
                return self.on_decode_error(message, exc)
            try:
                type_, payload = payload['task'], payload  # protocol v1
            except (TypeError, KeyError):
                return on_unknown_message(payload, message)
        try:
            strategy = strategies[type_]
        except KeyError as exc:
            return on_unknown_task(None, message, exc)
        else:
            try:
                strategy(
                    message, payload,
                    promise(call_soon, (message.ack_log_error,)),
                    promise(call_soon, (message.reject_log_error,)),
                    callbacks,
                )
            except (InvalidTaskError, ContentDisallowed) as exc:
                return on_invalid_task(payload, message, exc)
            except DecodeError as exc:
                return self.on_decode_error(message, exc)

    return on_task_received

One spot here really deserves our attention: self.strategies. But wait, where are our tasks??? After all that discussion above, they seem to go completely unused.

No need to worry; there must be machinery underneath. Let's track down where self.strategies gets initialized:

def update_strategies(self):
    loader = self.app.loader
    for name, task in items(self.app.tasks):
        self.strategies[name] = task.start_strategy(self.app, self)
        task.__trace__ = build_tracer(name, task, loader, self.hostname,
                                      app=self.app)

Now it all connects. Click into the start_strategy method to see what it does:

def start_strategy(self, app, consumer, **kwargs):
    return instantiate(self.Strategy, self, app, consumer, **kwargs)
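instantiate here resolves the dotted-path string stored in self.Strategy and constructs the result. A minimal sketch of that kind of resolution, using a stdlib symbol so it runs anywhere; the helper name mirrors celery's symbol_by_name but is a simplification:

```python
import importlib

def symbol_by_name(path):
    # "pkg.mod:attr" -> the attr object inside pkg.mod
    module_name, _, attr = path.partition(":")
    return getattr(importlib.import_module(module_name), attr)

# 'celery.worker.strategy:default' resolves the same way; a stdlib
# symbol is used here so the sketch has no celery dependency:
fn = symbol_by_name("os.path:join")
print(fn.__name__)   # join
```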

It produces a Strategy instance. Tracing further, the Strategy resolves to celery.worker.strategy:default. Click in; who knows, we might find a pleasant surprise:

def default(task, app, consumer,
            info=logger.info, error=logger.error, task_reserved=task_reserved,
            to_system_tz=timezone.to_system, bytes=bytes, buffer_t=buffer_t,
            proto1_to_proto2=proto1_to_proto2):
    """Default task execution strategy.

    Note:
        Strategies are here as an optimization, so sadly
        it's not very easy to override.
    """
    hostname = consumer.hostname
    connection_errors = consumer.connection_errors
    _does_info = logger.isEnabledFor(logging.INFO)

    # task event related
    # (optimized to avoid calling request.send_event)
    eventer = consumer.event_dispatcher
    events = eventer and eventer.enabled
    send_event = eventer and eventer.send
    task_sends_events = events and task.send_events

    call_at = consumer.timer.call_at
    apply_eta_task = consumer.apply_eta_task
    rate_limits_enabled = not consumer.disable_rate_limits
    get_bucket = consumer.task_buckets.__getitem__
    handle = consumer.on_task_request
    limit_task = consumer._limit_task
    limit_post_eta = consumer._limit_post_eta
    body_can_be_buffer = consumer.pool.body_can_be_buffer
    Request = symbol_by_name(task.Request)
    Req = create_request_cls(Request, task, consumer.pool, hostname, eventer)

    revoked_tasks = consumer.controller.state.revoked

    def task_message_handler(message, body, ack, reject, callbacks,
                             to_timestamp=to_timestamp):
        if body is None and 'args' not in message.payload:
            body, headers, decoded, utc = (
                message.body, message.headers, False, app.uses_utc_timezone(),
            )
            if not body_can_be_buffer:
                body = bytes(body) if isinstance(body, buffer_t) else body
        else:
            if 'args' in message.payload:
                body, headers, decoded, utc = hybrid_to_proto2(message,
                                                               message.payload)
            else:
                body, headers, decoded, utc = proto1_to_proto2(message, body)

        req = Req(
            message,
            on_ack=ack, on_reject=reject, app=app, hostname=hostname,
            eventer=eventer, task=task, connection_errors=connection_errors,
            body=body, headers=headers, decoded=decoded, utc=utc,
        )
        if _does_info:
            info('Received task: %s', req)
        if (req.expires or req.id in revoked_tasks) and req.revoked():
            return

        signals.task_received.send(sender=consumer, request=req)

        if task_sends_events:
            send_event(
                'task-received',
                uuid=req.id, name=req.name,
                args=req.argsrepr, kwargs=req.kwargsrepr,
                root_id=req.root_id, parent_id=req.parent_id,
                retries=req.request_dict.get('retries', 0),
                eta=req.eta and req.eta.isoformat(),
                expires=req.expires and req.expires.isoformat(),
            )

        bucket = None
        eta = None
        if req.eta:
            try:
                if req.utc:
                    eta = to_timestamp(to_system_tz(req.eta))
                else:
                    eta = to_timestamp(req.eta, app.timezone)
            except (OverflowError, ValueError) as exc:
                error("Couldn't convert ETA %r to timestamp: %r. Task: %r",
                      req.eta, exc, req.info(safe=True), exc_info=True)
                req.reject(requeue=False)
        if rate_limits_enabled:
            bucket = get_bucket(task.name)

        if eta and bucket:
            consumer.qos.increment_eventually()
            return call_at(eta, limit_post_eta, (req, bucket, 1),
                           priority=6)
        if eta:
            consumer.qos.increment_eventually()
            call_at(eta, apply_eta_task, (req,), priority=6)
            return task_message_handler
        if bucket:
            return limit_task(req, bucket, 1)

        task_reserved(req)
        if callbacks:
            [callback(req) for callback in callbacks]
        handle(req)
    return task_message_handler

Whoa, interesting. In other words, as soon as celery receives a message, it invokes that message's execution strategy, which is default by default. And note: the value actually stored in strategies is the task_message_handler function returned above.
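The pattern here is worth isolating: strategies maps each task name to a closure that has already captured its task's state, so per-message dispatch is one dict lookup plus a call. A toy sketch (all names are illustrative):

```python
def default(task_name):
    # stands in for the real default(): returns a closure that has
    # already captured everything it needs about this task
    def task_message_handler(message):
        return f"{task_name} <- {message!r}"
    return task_message_handler

strategies = {name: default(name) for name in ("proj.add", "proj.mul")}

# what on_task_received effectively does after reading the header:
type_ = "proj.add"                    # from message.headers['task']
print(strategies[type_]("payload"))   # proj.add <- 'payload'
```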

The line that especially deserves attention is this one:

Req = create_request_cls(Request, task, consumer.pool, hostname, eventer)

Click in to see what gets created; now this is fun:

def create_request_cls(base, task, pool, hostname, eventer,
                       ref=ref, revoked_tasks=revoked_tasks,
                       task_ready=task_ready, trace=trace_task_ret):
    default_time_limit = task.time_limit
    default_soft_time_limit = task.soft_time_limit
    apply_async = pool.apply_async
    acks_late = task.acks_late
    events = eventer and eventer.enabled

    class Request(base):

        def execute_using_pool(self, pool, **kwargs):
            task_id = self.task_id
            if (self.expires or task_id in revoked_tasks) and self.revoked():
                raise TaskRevokedError(task_id)

            time_limit, soft_time_limit = self.time_limits
            # note this call
            result = apply_async(
                trace,
                args=(self.type, task_id, self.request_dict, self.body,
                      self.content_type, self.content_encoding),
                accept_callback=self.on_accepted,
                timeout_callback=self.on_timeout,
                callback=self.on_success,
                error_callback=self.on_failure,
                soft_timeout=soft_time_limit or default_soft_time_limit,
                timeout=time_limit or default_time_limit,
                correlation_id=task_id,
            )
            # cannot create weakref to None
            # pylint: disable=attribute-defined-outside-init
            self._apply_result = maybe(ref, result)
            return result

        def on_success(self, failed__retval__runtime, **kwargs):
            failed, retval, runtime = failed__retval__runtime
            if failed:
                if isinstance(retval.exception, (
                        SystemExit, KeyboardInterrupt)):
                    raise retval.exception
                return self.on_failure(retval, return_ok=True)
            task_ready(self)

            if acks_late:
                self.acknowledge()

            if events:
                self.send_event(
                    'task-succeeded', result=retval, runtime=runtime,
                )

    return Request

What this returns is a nested class, Request, which has a method execute_using_pool that in turn calls apply_async. Note: this is where our task really starts to execute, since the task name, arguments and everything else are passed in. And as the name apply_async suggests, it is most likely related to concurrency, which is not the focus of this chapter. Let's now go back to task_message_handler and look only at the key code:

def task_message_handler(message, body, ack, reject, callbacks,
                         to_timestamp=to_timestamp):

    req = Req(
        message,
        on_ack=ack, on_reject=reject, app=app, hostname=hostname,
        eventer=eventer, task=task, connection_errors=connection_errors,
        body=body, headers=headers, decoded=decoded, utc=utc,
    )
    # log the received task
    if _does_info:
        info('Received task: %s', req)
    if (req.expires or req.id in revoked_tasks) and req.revoked():
        return
    # send the task_received signal
    signals.task_received.send(sender=consumer, request=req)
    # send the task-received event
    if task_sends_events:
        send_event(
            'task-received',
            uuid=req.id, name=req.name,
            args=req.argsrepr, kwargs=req.kwargsrepr,
            root_id=req.root_id, parent_id=req.parent_id,
            retries=req.request_dict.get('retries', 0),
            eta=req.eta and req.eta.isoformat(),
            expires=req.expires and req.expires.isoformat(),
        )

    bucket = None
    eta = None
    # handle ETA (delayed) tasks
    if req.eta:
        try:
            if req.utc:
                eta = to_timestamp(to_system_tz(req.eta))
            else:
                eta = to_timestamp(req.eta, app.timezone)
        except (OverflowError, ValueError) as exc:
            error("Couldn't convert ETA %r to timestamp: %r. Task: %r",
                  req.eta, exc, req.info(safe=True), exc_info=True)
            req.reject(requeue=False)
    if rate_limits_enabled:
        bucket = get_bucket(task.name)

    if eta and bucket:
        consumer.qos.increment_eventually()
        return call_at(eta, limit_post_eta, (req, bucket, 1),
                       priority=6)
    if eta:
        consumer.qos.increment_eventually()
        call_at(eta, apply_eta_task, (req,), priority=6)
        return task_message_handler
    if bucket:
        return limit_task(req, bucket, 1)

    task_reserved(req)
    # invoke any registered callbacks
    if callbacks:
        [callback(req) for callback in callbacks]
    # here it comes, the key call
    handle(req)

Checking the source, handle is actually bound as handle = consumer.on_task_request. But on_task_request is not defined on the Consumer class either; it is passed in at initialization time, so we keep looking upward:

class Consumer(bootsteps.StartStopStep):
    """Bootstep starting the Consumer blueprint."""

    last = True

    def create(self, w):
        if w.max_concurrency:
            prefetch_count = max(w.max_concurrency, 1) * w.prefetch_multiplier
        else:
            prefetch_count = w.concurrency * w.prefetch_multiplier
        c = w.consumer = self.instantiate(
            w.consumer_cls, w.process_task,
            hostname=w.hostname,
            task_events=w.task_events,
            init_callback=w.ready_callback,
            initial_prefetch_count=prefetch_count,
            pool=w.pool,
            timer=w.timer,
            app=w.app,
            controller=w,
            hub=w.hub,
            worker_options=w.options,
            disable_rate_limits=w.disable_rate_limits,
            prefetch_multiplier=w.prefetch_multiplier,
        )
        return c

So consumer.on_task_request actually points to w.process_task. Now we're getting somewhere; keep searching in Worker, where all we find is this:

def _process_task(self, req):
    """Process task by sending it to the pool of workers."""
    try:
        req.execute_using_pool(self.pool)
    except TaskRevokedError:
        try:
            self._quick_release()   # Issue 877
        except AttributeError:
            pass

Only here does execute_using_pool finally get called. But there's a snag: _process_task has a leading underscore, and nowhere in the Worker class do we find a process_task function. Going by experience alone, we'd be stuck. But remember I said execute_using_pool looks related to multiprocessing? Even if it isn't, it most likely has something to do with the Pool, so let's check Worker's sub-components:

class Pool(bootsteps.StartStopStep):
    def create(self, w):
        semaphore = None
        max_restarts = None
        if w.app.conf.worker_pool in GREEN_POOLS:  # pragma: no cover
            warnings.warn(UserWarning(W_POOL_SETTING))
        threaded = not w.use_eventloop or IS_WINDOWS
        procs = w.min_concurrency
        # sure enough, here it is
        w.process_task = w._process_task
        """
        irrelevant code omitted
        """
        return pool

    def info(self, w):
        return {'pool': w.pool.info if w.pool else 'N/A'}

    def register_with_event_loop(self, w, hub):
        w.pool.register_with_event_loop(hub)
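The chain we just uncovered (Consumer → w.process_task → w._process_task → req.execute_using_pool → pool) condenses into a runnable sketch; every class below is a bare stand-in for the real one:

```python
class Request:
    def __init__(self, name, args):
        self.name, self.args = name, args

    def execute_using_pool(self, pool):
        return pool.apply(self.name, self.args)

class FakePool:
    def apply(self, name, args):
        return f"executing {name}{args}"

class Worker:
    def __init__(self):
        self.pool = FakePool()
        # what the Pool bootstep does: expose the private method under
        # the name the Consumer will be handed
        self.process_task = self._process_task

    def _process_task(self, req):
        return req.execute_using_pool(self.pool)

class Consumer:
    def __init__(self, on_task_request):
        self.on_task_request = on_task_request   # handle = consumer.on_task_request

    def handle(self, req):
        return self.on_task_request(req)

w = Worker()
c = Consumer(w.process_task)
print(c.handle(Request("proj.add", (2, 3))))   # executing proj.add(2, 3)
```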

Case closed. We now know how our tasks are discovered and how they get consumed. Exactly how a task is executed in detail we still don't know, but we have very luckily found the doorway into that blank space: result = apply_async().
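As a parting sketch, the hand-off that apply_async performs can be imitated with the stdlib pool. Celery actually goes through billiard's extended apply_async (with accept_callback, timeout_callback, and so on); the stdlib analog only offers callback and error_callback, and trace below is a stand-in for trace_task_ret:

```python
from multiprocessing.pool import ThreadPool

def trace(task_name, args):
    # stand-in for trace_task_ret: look up the task by name and run it
    return sum(args)

results = []
with ThreadPool(2) as pool:
    async_result = pool.apply_async(
        trace, args=("proj.add", (2, 3)),
        callback=results.append,   # plays the role of on_success
        error_callback=print,      # plays the role of on_failure
    )
    async_result.wait()            # callback has fired once this returns

print(results)   # [5]
```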