InheritableThreadLocal の使い方で、ハマったのでメモ。

Oracle Technology Network for Java Developers | Oracle Technology Network | Oracle

Javaでセッションの情報などを、ThreadLocalに保存しておいて後から参照するのは良くある使い方だと思います。しかし、今回では1セッション中にスレッドを作って処理を分散させる必要がありました。
そのためにお誂え向きな機能が JDKに用意されていて、それが表題の InheritableThreadLocalクラスになります。具体的には、あるスレッドで InheritableThreadLocalクラスのインスタンスを生成した場合に、そのスレッドから新規に作成されたスレッドに対して InheritableThreadLocalの情報が引き継がれます。その点が Inheritableなんですね。それ以外の部分は、ThreadLocalを継承している通り各スレッドで別々のインスタンスが生成されることになります。

さて、今回このクラスを使った際に問題が起きたコードは以下のものです。

import java.util.Map;
import java.util.Map.Entry;
import java.util.TreeMap;

public class InheritableThreadLocalSample {
	private static final InheritableThreadLocal<Map<String, String>> info = new InheritableThreadLocal<Map<String, String>>() {
		@Override
		protected Map<String, String> initialValue() {
			return new TreeMap<String, String>();
		}
	};
	public static String get(String key) {
		return info.get().get(key);
	}
	public static void put(String key, String value) {
		info.get().put(key, value);
	}
	public static String getAll() {
		StringBuilder builder = new StringBuilder();
		builder.append('[');
		for (Entry<String, String> entry : info.get().entrySet()) {
			builder.append(entry.getKey()).append('=').append(entry.getValue()).append(',');
		}
		builder.append(']');
		return builder.toString();
	}
}

普通のThreadLocalでは、まったく問題がなさそうな実装なのですが、情報が引き継がれるところでバグが起きてしまいました。
InheritableThreadLocalクラスは、ThreadLocalを継承しているのですがメソッドが1つ追加されています。それが childValueメソッドです。あるスレッドから別のスレッドが作成された際に、このメソッドが呼び出されて情報が引き継がれます。そして、デフォルトの実装では前の値をそのまま返すだけになっています。
そんな訳で、今回のバグは ThreadLocalのつもりだったMapのインスタンスがそのまま別のスレッドに渡されて使われてしまっていた。さらに、そのMapのインスタンスが Mainスレッドで生成されたものが全てのスレッドに渡されていたため、無関係だと思っていた場所での変更が波及してしまっていた、というものでした。
修正するのは簡単で、上記の childValueメソッドをオーバーライドして新たなインスタンスを作るだけでした。修正後のソースは以下の通り。

import java.util.Map;
import java.util.Map.Entry;
import java.util.TreeMap;

public class InheritableThreadLocalSample {
	private static final InheritableThreadLocal<Map<String, String>> info = new InheritableThreadLocal<Map<String, String>>() {
		@Override
		protected Map<String, String> initialValue() {
			return new TreeMap<String, String>();
		}
		@Override
		protected Map<String, String> childValue(Map<String, String> parentValue) {
			return new TreeMap<String, String>(parentValue);
		}
	};
	public static String get(String key) {
		return info.get().get(key);
	}
	public static void put(String key, String value) {
		info.get().put(key, value);
	}
	public static String getAll() {
		StringBuilder builder = new StringBuilder();
		builder.append('[');
		for (Entry<String, String> entry : info.get().entrySet()) {
			builder.append(entry.getKey()).append('=').append(entry.getValue()).append(',');
		}
		builder.append(']');
		return builder.toString();
	}
}