掘金 后端 ( ) • 2023-05-24 17:22

起因

前段时间写一个demo时,用到了Threadlocal,考虑其优异的实用性,于是在这里便将知识点其进行简单梳理,一起回顾一下这个经典功能。
在日常业务中,Threadlocal应用面也是非常广泛的,比如Hibernate中,用SessionFactory创建session之后,因为session是线程不安全的,里面包含了数据库操作的各种状态信息,如果每个线程都共享一个session,那么麻烦就大了。所以Hibernate将session放到threadlocal中,保证线程安全的同时,也能避免频繁创建和销毁,影响应用性能。

介绍

以自身作为key,存一个变量到当前线程上下文集合里。在线程生存周期内,任何地方任何时间都能获取到此变量。但是有个前提,不能是线程池,因为线程池中的线程是公用的,任务对其来说只是过客,当前任务设置的一个value,如果不清除的话,会被其他任务读取,可能会造成数据不安全和隐藏的bug。
内部数据存储结构示意图【粗糙勿喷】
image.png

源码

下面分为三个部分来介绍ThreadLocal的源码,分别是写入、读取,ThreadLocal的删除方法比较简单,没有什么好介绍的,主要还是看看ThreadLocalMap.remove()方法。

读取

读取操作比较简单,从当前线程中,获取threadlocalMap集合,以当前ThreadLocal的实例为key来获取相应的value,如果map是空,则进行初始化,创建

    public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    }

写入

写入的代码比较简单,一眼便知,不需要花费太多的脑力。第一步获取当前线程,读取线程中的ThreadLocalMap集合,然后写入或者初始化。因为与其他线程的操作是隔离的,所以不会有线程安全问题。

    public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }

初始化ThreadLocalMap集合,将当前数据作为初始元素写入Map。

void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }

getMap方法中,直接获取的就是Thread线程实例中的threadLocals集合。

    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }

应用

设定一个ThreadLocal变量,可以是静态的,供线程以此变量为key,来保存特定的数据。下面的例子中,LANGUAGES用来给线程设置语言环境,用于在线程生命周期内的任何地方取用。

    private final static ThreadLocal<LanguageEnum> LANGUAGES = new ThreadLocal<LanguageEnum>();
// 设置数据
    public static void setLanguage(String language) {
        LANGUAGES.set(LanguageEnum.get(language));
    }
// 读取之前保存的数据
    public static LanguageEnum getLanguage() {
        log.info("THREAD_LOCAL,thread:{},language:{}",Thread.currentThread().getName(),LANGUAGES.get());
        return LANGUAGES.get();
    }
// 线程生命周期结束之后,删除数据,以免内存泄露
    public static void remove() {
        LANGUAGES.remove();
    }

延伸

如果一个子线程需要获取父线程中的threadlocal变量,需要如何处理呢?Java语言开发者已经为我们考虑到了这个场景,在Thread中有个集合:【inheritableThreadLocals】,其集合类型和【threadLocals】一样,也是ThreadLocalMap,就是用来存储数据,方便子线程读取的。

    /*
     * InheritableThreadLocal values pertaining to this thread. This map is
     * maintained by the InheritableThreadLocal class.
     */
    ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;

下面给出一个测试用例

    @Test
    public void testInheritable() {
        InheritableThreadLocal<String> threadLocal = new InheritableThreadLocal<>();
        threadLocal.set("测试数据-主线程");
        log.info("主线程设置数据:{}", threadLocal.get());
        new Thread(() -> {
            String data = threadLocal.get();
            log.info("子线程获取数据:{}", data);
        }).start();
    }

输出内容如下,可以到主线程设置的变量,在子线程中能获取主线程threadlocal保存的数据。

11:06:46.875 [main] INFO com.neteasexxx.sync.LockTest - 主线程设置数据:测试数据-主线程
11:06:46.934 [Thread-0] INFO com.netease.xxx.sync.LockTest - 子线程获取数据:测试数据-主线程

ThreadLocalMap介绍

这是定义在ThreadLocal中的一个静态内部类,看起来是一个Map,其实并没有实现Map接口,其数据是保存在一个数组结构中,【private Entry[] table; 】。这里有个比较有意思的设定,Entry也不是Map中定义的Entry,而是一个弱引用,key:ThreadLocal,value:Object。一旦引用对象[ThreadLocal]被置为null,表示其不再被引用,这个数据就会被数组擦除,从而被垃圾回收。

什么是弱引用

Java中的弱引用具体指的是java.lang.ref.WeakReference类,我们首先来看一下官方文档对它做的说明:
弱引用对象的存在不会阻止它所指向的对象变被垃圾回收器回收。弱引用最常见的用途是实现规范映射(canonicalizing mappings,比如哈希表)。
本文介绍的ThreadLocalMap就是用弱引用实现的。

为什么要用弱引用

在类中,ThreadLocal被定义为一个弱引用,如果ThreadLocal实例被设置为空,那么垃圾回收器会将其映射的value进行回收,而不需要等到手动设置table[i]=null,可以在一定程度上避免内存泄露。

读取

读取操作逻辑简单,这里就不贴源码了,大概步骤是取key的hash值,如果不为空则返回,如果为空的话,可能发生hash冲突,需要往后遍历查找。如果最终还是没有找到当前key的数据,则返回空。
这里讲解一下关键代码:

    while (e != null) {
        ThreadLocal<?> k = e.get();
        // 找到了,返回
        if (k == key)
            return e;
        // 软引用失效,需要清理数据
        if (k == null)
            expungeStaleEntry(i);
        // 没找到继续往后查找
        else
            i = nextIndex(i, len);
        e = tab[i];
    }

写入

写入方法是ThreadLocalMap类的灵魂核心,所以此处将源代码直接贴出来,不作剪辑。

private void set(ThreadLocal<?> key, Object value) {

    Entry[] tab = table;
    int len = tab.length;
    // 获取当前key的位置
    int i = key.threadLocalHashCode & (len-1);
    // 如果该位置有数据,则表明发生了hash冲突,需要往后遍历,更新或者替换
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();
        // 如果key相同,则更新
        if (k == key) {
            e.value = value;
            return;
        }
        // 如果原集合中的槽位数据,其软引用已经丢失,则对当前位置做替换操作
        // 注意:继续往后遍历,是有可能找到此key的,只是需要对这种hash冲突进行处理。
        if (k == null) {
            replaceStaleEntry(key, value, i);
            return;
        }
    }
    // 如果该位置没有被占用,则新建一个弱引用,设置进table集合
    tab[i] = new Entry(key, value);
    int sz = ++size;
    // 如果table数组中的数据已经超过容量的2/3阀值,
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

rehash的时候,先清除无效的Entry,如果当前集合剩余空间不足1/4,则开始扩容

private void rehash() {
    // 清除无效的Entry
    expungeStaleEntries();

    // Use lower threshold for doubling to avoid hysteresis
    if (size >= threshold - threshold / 4)
        // 扩容
        resize();
}

删除

删除方法的逻辑比较简单:

  1. 定位到元素真正的槽位【因为可能发生过hash冲突】;
  2. 置空软引用;
  3. 将相应槽位的数据设置为空,并将集合内的数据进行重新hash。
private void remove(ThreadLocal<?> key) {
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        // 判断是否发生过Hash冲突,如果有冲突,则需要往后继续遍历
        if (e.get() == key) {
            // 将引用设置为null
            e.clear();
            // 将引用的值设置为null
            expungeStaleEntry(i);
            return;
        }
    }
}

replaceStaleEntry - 替换操作

参数为key,value,staleSlot,方法内分为三步走

  1. 根据参数中提供的槽位位置【staleSlot】,往前找到一个失效的槽位【slotToExpunge】,便于后续清理;
  2. 从staleSlot开始往后找,尝试找到当前key在集合中的旧值,找到以后,将旧值移动到staleSlot的位置,并进行清理,范围【slotToExpunge - length】;
  3. 如果此key没有旧值,则将staleSlot位置的数据设置为null,并将参数中的value设置到staleSlot位置。
    这里看看第二步的关键代码
    if (k == key) {
        e.value = value;
        // 移动位置,i肯定在staleSlot的后边,所以是将旧值往前移动了
        tab[i] = tab[staleSlot];
        tab[staleSlot] = e;

        // Start expunge at preceding stale entry if it exists
        if (slotToExpunge == staleSlot)
            slotToExpunge = i;
        // 清理数据,因为原先的tab[staleSlot]是无效数据,而且后移了,需要对其清理
        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
        return;
    }

扩容

扩容时,新建一个原先的容量的两倍的Entry数组,将旧数组中的数据,重新hash计算之后,转移到新的队列中,最后重新设置阀值,等待下次扩容。
如果有hash冲突,则继续往后推进,直至找到一个空位,将数据写入。

    // 重新计算hash值
    int h = k.threadLocalHashCode & (newLen - 1);
    // 如果有hash冲突,则继续往后推进,直至找到一个空位,将数据写入
    while (newTab[h] != null)
        h = nextIndex(h, newLen);
    newTab[h] = e;
    count++;

cleanSomeSlots - 整理数据

方法返回一个boolean值,表示是否发生过删除操作。方法内遍历整个Entry数组,判断相应槽位数据的弱引用是否为空,如果为空,则表明需要被清除。

    do {
        i = nextIndex(i, len);
        Entry e = tab[i];
        // 弱引用为空
        if (e != null && e.get() == null) {
            n = len;
            removed = true;
            // 清理槽位数据,并返回下一个为空的槽位,继续遍历
            i = expungeStaleEntry(i);
        }
        // n是集合的容量
    } while ( (n >>>= 1) != 0);

expungeStaleEntry - 清除特定槽位的数据

此方法返回值是int,照理说清除完成之后,不需要返回什么;文档里是这么说的“清除完成之后,返回下一个为null的槽位”。所以这个方法分为了两个部分:

  1. 清除特定槽位数据
    Entry[] tab = table;
    // expunge entry at staleSlot
    tab[staleSlot].value = null;
    tab[staleSlot] = null;
    size--;
  1. 找到下一个为null的槽位,期间会有rehash的操作,因为上一步清除了一个数据,之前发生的hash冲突可能已经不存在了,所以需要对其进行位置还原,避免后续找不到该数据,所以此处的操作目的不是为了找到空槽位,而是重新hash。
    // 重新计算hash值
    int h = k.threadLocalHashCode & (len - 1);
    // 如果hash值与entry元素的当前位置不一致,则说明之前发生过hash冲突,被迫后移了位置,需要修正。
    if (h != i) {
        tab[i] = null;
        // 找到一个空位,将数据插回去。这里遍历的原因是有可能这个位置发生过不止一次hash冲突,后面可能有多个entry全是要到槽位【h】的,这里操作的最终结果是:相同hash值的entry,整体前移一个槽位。
        while (tab[h] != null)
            h = nextIndex(h, len);
        tab[h] = e;
    }