theme: channing-cyan highlight: a11y-dark
手写ThreadLocal
声明: 本文使用JDK11,threadLocal场景下应该和JDK8没有差异,习惯使然
前言
还在面试的前一天上班囫囵吞枣找资料?还在面试前一晚辗转难眠?还在去面试的路上因为准备不足而忐忑不安?还在每次面试后因为面试官问的问题太刁钻而破口大骂?今天,教你手把手实现ThreadLocal,以后面试吊打面试官,脚踢HR🤣😁。
ThreadLocal用法简介
使用场景
按我的理解, ThreadLocal只是保存上下文的一个工具。 就我自己在项目用到的场景
- 保存当前用户信息; 账号,token等
- 记录调用链路;sessionId,traceId等,跨服务调用时,出了问题方便溯源;
- 打印日志;配合log4j中的MDC,用户某个操作统一加上日志前缀,方便跟踪;
- 缓存;
- 事务管理;保存事务上下文,方便回滚
- 动态数据源切换;
基本用法
public class Test0 {
public static void main(String[] args) {
ThreadLocal<String> threadLocal = new ThreadLocal<>();
threadLocal.set("hi");
CompletableFuture.runAsync(() -> {
threadLocal.set("hello");
System.out.println("thread: " + Thread.currentThread().getName() + ", value: " + threadLocal.get());
});
System.out.println("thread: " + Thread.currentThread().getName() + ", value: " + threadLocal.get());
}
}
结果如下:
thread: main, value: hi
thread: ForkJoinPool.commonPool-worker-19, value: hello
简单来说就是每个线程可以在ThreadLocal保存独立的值,不会互相影响;
核心方法
- set
- 作用:在当前线程保存某个值
- get
- 作用:获取保存的值
实现
基本接口
按照惯例先弄个接口方便迭代
public interface ThreadLocalInf<T> {
void set(T value) ;
T get() ;
}
版本-01
在还没看源码前,我的思路就是用一个Map,key保存线程ID,value保存值;
实现源码:
public class MyThreadLocal1<T> implements ThreadLocalInf<T> {
@Override
public void set(T value) {
getMap().put(Thread.currentThread().getId(), value);
}
@Override
public T get() {
return getMap().get(Thread.currentThread().getId());
}
private final Map<Long, T> map = new ConcurrentHashMap<>();
public Map<Long,T> getMap() {
return map;
}
public static void main(String[] args) {
ThreadLocalInf<String> threadLocal = new MyThreadLocal1<>();
threadLocal.set("hi");
CompletableFuture.runAsync(() -> {
threadLocal.set("hello");
System.out.println("thread: " + Thread.currentThread().getName() + ", value: " + threadLocal.get());
});
System.out.println("thread: " + Thread.currentThread().getName() + ", value: " + threadLocal.get());
}
}
这样实现就是简单直观,缺点就是要用到ConcurrentHashMap来保证线程安全,高并发环境效率相对低一点。
要是产品要我实现一个ThreadLocal,我就这么写了🐶
版本-02
这会我读了下ThreadLocal的源码之后发现, 它并非使用ThreadLocal持有Map,而是使用Thread持有;好处就是不用担心线程安全问题。
JDK源码:
public
class Thread implements Runnable {
//...
/* ThreadLocal values pertaining to this thread. This map is maintained
* by the ThreadLocal class. */
ThreadLocal.ThreadLocalMap threadLocals = null;
//...
}
static class ThreadLocalMap {
/**
* Set the value associated with key.
*
* @param key the thread local object
* @param value the value to be set
*/
private void set(ThreadLocal<?> key, Object value) {
//...
}
}
所以说我们这里还要自己实现一个Thread😅;
并且通过源码我们可以发现,这个ThreadLocalMap是以ThreadLocal作为key的,所以接下来我们实现MyThreadLocal2还得实现一下hashcode和equals方法
考虑到我们没办法修改jdk中thread的代码,我们自己实现一个MyThread类,然后里面维护一个Map
class MyThread2 extends Thread {
Map<ThreadLocalInf<?>, Object> threadLocalMap = new HashMap<>();
public MyThread2(Runnable runnable) {
super(runnable);
}
}
public class MyThreadLocal2<T> implements ThreadLocalInf<T> {
private static final AtomicInteger nextId = new AtomicInteger(0);
private final int id = nextId.getAndIncrement();
// hashCode没必要写得太复杂,因为每个ThreadLocal都是唯一的,给出一个自增的id就可以了
@Override
public int hashCode() {
return id;
}
// 这里equals == 即可,因为每个ThreadLocal都是唯一的
@Override
public boolean equals(Object obj) {
return this == obj ;
}
@Override
public void set(T value) {
Thread thread = Thread.currentThread();
if(thread instanceof MyThread2) {
MyThread2 myThread = (MyThread2) thread;
myThread.threadLocalMap.put(this, value);
} else {
throw new UnsupportedOperationException();
}
}
@Override
public T get() {
Thread thread = Thread.currentThread();
if( thread instanceof MyThread2) {
MyThread2 myThread = (MyThread2) thread;
return (T) myThread.threadLocalMap.get(this);
} else {
throw new UnsupportedOperationException();
}
}
public static void main(String[] args) {
// 创建线程池 , 使用MyThread
ExecutorService executorService = Executors.newCachedThreadPool(MyThread2::new);
// 创建10个ThreadLocal
List<ThreadLocalInf<String>> localList = new ArrayList<>();
for (int i = 0; i < 10; i++) {
localList.add(new MyThreadLocal2<>());
}
//这里我们上一下强度, 开100个线程测试
for (int i = 0; i < 100; i++) {
CompletableFuture.runAsync(() -> {
for (int j = 0; j < localList.size(); j++) {
String val = Thread.currentThread().getName() + "-" + j;
ThreadLocalInf<String> local = localList.get(j);
System.out.println("thread :" + Thread.currentThread().getName() + ",set value: " + val);
local.set(val);
}
// 暂停5秒
try { TimeUnit.SECONDS.sleep(5); } catch (InterruptedException e) { throw new RuntimeException(e); }
for (ThreadLocalInf<String> local : localList) {
System.out.println("thread :" + Thread.currentThread().getName() + ",get value: " + local.get());
}
}, executorService);
}
try {
TimeUnit.SECONDS.sleep(10);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
System.out.println("finish");
}
}
到这里你应该也基本明白ThreadLocal的构造, 接下来我们再看看细节。
版本-03
ThreadLocal可能就是这么简单,但一旦到了面试的场景,防止面试官使劲扣点东西来问,我们还是得再深入研究一下。
接下来说说老生常谈的内存泄漏问题。
为什么会出现内存泄漏
内存泄漏简单来说就是不再使用的内存无法被回收,导致内存占用越来越大,最终导致OOM(Out of Memory 内存溢出)。
我们经常使用的springboot每个请求都会使用一个线程,请求结束后线程并不会销毁,而是放到线程池中,等待下一次请求;如果在这个请求结束前,你保存了大量的数据到ThreadLocal中,但没有主动remove,而且这个Thread由于使用的是线程池,是会一直存在的,它所保存的threadLocals的对象也会一直存在。那么这部分数据就会一直存在内存中,从而很容易导致内存泄漏。
因此内存泄漏通常是因为我们没有主动remove。弱引用在清理内存上面只起到了很小的作用,如果开发的过程中主动remove,那么完全可以不用弱引用。
ThreadLocal中弱引用的作用,以及ThreadLocal对弱引用的后续处理
简单来说,发生GC后,并且ThreadLocal没有其他强引用,ThreadLocalMap中的Entry的key就会被回收(等同于调用weakReference.clear())变为null。
这里说一下题外话,我们一般使用ThreadLocal 会这么定义:public static final ThreadLocal local = new ThreadLocal<>(); 这样做相当于加了个不可更改的强引用,因此,ThreadLocalMap中的Entry的key是不会被回收;所以我认为开发的过程中不必太在意这个Entry中的WeakReference。
这里示范一下使用弱引用的例子
例子代码:
/**
* JDK11
*/
public class WeakReferenceExample {
public static void main(String[] args) {
String str = new String("Hello, World!"); //强制创建对象在堆中,而不是常量池中;常量池中的对象不会被回收,哪怕只有弱引用
// String str = "Hello, World!" //对象会在常量池
WeakReference<String> weakReference = new WeakReference<>(str);
str = null; //尝试取消注释该行代码,看看效果
System.gc(); //相当于调。weakReference.clear()
System.out.println(weakReference.get()); //返回null
}
}
必须手动str = null; 弱引用才能生效。也就是gc后清理没有强引用的对象。
如果返回null,那么就是弱引用里面的对象被回收了。
JDK源码:
/**
* The entries in this hash map extend WeakReference, using
* its main ref field as the key (which is always a
* ThreadLocal object). Note that null keys (i.e. entry.get()
* == null) mean that the key is no longer referenced, so the
* entry can be expunged from table. Such entries are referred to
* as "stale entries" in the code that follows.
*/
static class Entry extends WeakReference<ThreadLocal<?>> {
/** The value associated with this ThreadLocal. */
Object value;
Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}
我们先做个假设,哪怕这个WeakReference生效了 ,就是GC后回收了key,但是value是强引用(只有被super(k)框住的才是弱引用😥)。一般来说ThreadLocal作为key本身是不太占用内存的,但是value是用户传值,占用内存可能会很大,那么ThreadLocal是如何自动清理掉没用的value的?
这里先说结论,ThreadLocal调用get()或者set()或者内部map触发扩容的时候,都会检查对应的key是否为null,如果是null,就会把这个Entry的value置为null;
感兴趣的话可以阅读下源码ThreadLocal 里面的 expungeStaleEntry,它的作用除了清除key为空的entry外还重新排列与被清空的key产生hash冲突的元素的索引,这里就不贴了代码了,免得你们以为我刷字数。
通过上面的解释,相信你也知道这个WeakReference的确没什么用了😎(可能有但是我的使用场景用不上)。
版本-03代码实现
接下来我们也实现一下ThreadLocalMap中的对entry的null值处理
这里我就不用Map了,直接用数组代替map的存储功能,否则实现处理hash冲突的代码太长了。
重新列一下我们要实现的功能
- set: 保存值
- 如果key为null,直接替换value , 如果需要扩容,清理所有key为null的entry
- get: 获取值
- 如果entry的key为null,删除entry
实现代码:
/**
* 取消hash冲突的实现,简单使用List保存Entry
*/
class ThreadLocalMap {
//照搬ThreadLocalMap.Entry
static class Entry extends WeakReference<ThreadLocalInf<?>> {
Object value;
Entry(ThreadLocalInf<?> k, Object v) {
super(k);
value = v;
}
}
private final List<Entry> table = new ArrayList<>();
private int getIndex(ThreadLocalInf<?> key) {
return key.hashCode();
}
public Object get(ThreadLocalInf<?> key) {
return getByIndex(getIndex(key));
}
public Object getByIndex(int index) {
Entry entry = table.get(index);
if (entry == null) {
return null;
}
if (entry.get() == null) {
entry.value = null;
table.remove(entry);
return null;
} else {
return entry.value;
}
}
/**
* 扩容时清理无效的Entry
*/
public void put(ThreadLocalInf<?> key, Object value) {
int index = getIndex(key);
// 扩容
while (table.size() <= index) {
table.add(null);
for (int i = 0; i < table.size(); i++) {
if (table.get(i) != null && table.get(i).get() == null) {
table.set(i, null);
}
}
}
table.set(index, new Entry(key, value));
}
}
/**
* 自定义线程类,为了自定义ThreadLocalMap
*/
class MyThread3 extends Thread {
ThreadLocalMap threadLocalMap = new ThreadLocalMap();
public MyThread3(Runnable runnable) {
super(runnable);
}
}
public class MyThreadLocal3<T> implements ThreadLocalInf<T> {
private static final AtomicInteger nextId = new AtomicInteger(0);
private final int id = nextId.getAndIncrement();
@Override
public int hashCode() {
return id;
}
@Override
public boolean equals(Object obj) {
return this == obj;
}
@Override
public void set(T value) {
Thread thread = Thread.currentThread();
if (thread instanceof MyThread3) {
MyThread3 myThread = (MyThread3) thread;
myThread.threadLocalMap.put(this, value);
} else {
throw new UnsupportedOperationException();
}
}
@Override
public T get() {
Thread thread = Thread.currentThread();
if (thread instanceof MyThread3) {
MyThread3 myThread = (MyThread3) thread;
return (T) myThread.threadLocalMap.get(this);
} else {
throw new UnsupportedOperationException();
}
}
}
以上就是最终的版本实现,上面的代码还有缺陷,就是list会无限扩容,有兴趣的可以自行优化下,思路就是新增ThreadLocal的时候给个最小的可用索引。
有兴趣可以参考上述版本02做一些测试。 上述代码参考了netty源码中的 io.netty.util.concurrent.FastThreadLocal ,它也是使用数组实现,不存在hash冲突。有兴趣的同学可以去学习一下