Java Concurrency in Practice におけるメモイゼーションの実装

書籍『Java Concurrency in Practice』で紹介されているスレッドセーフなメモイゼーションの実装を、ほぼそのまま掲載しています。


計算処理を表すインターフェース。型パラメータ A は入力の型、V は出力（計算結果）の型。

/**
 * A computation that maps an input of type {@code A} to a result of type {@code V}.
 *
 * @param <A> the input type
 * @param <V> the result type
 */
public interface Computable<A, V> {

    /**
     * Computes the result for the given input.
     *
     * @param arg the input value
     * @return the computed result
     * @throws InterruptedException implementations may throw this when the computation is
     *     interrupted
     */
    V compute(A arg) throws InterruptedException;
}


Memoizer は高価な計算処理の結果をキャッシュする。ConcurrentHashMap と Future（FutureTask）を利用することで、synchronized による同期を使わずにスレッドセーフを実現している。同じ引数に対する計算が複数のスレッドから同時に要求されても計算は一度だけ実行され、他のスレッドはその結果を待つ。

public class Memoizer<A, V> implements Computable<A, V> {
    private final ConcurrentMap<A, Future<V>> cache =
            new ConcurrentHashMap<A, Future<V>>();
    private final Computable<A, V> computable;

    public Memoizer(Computable<A, V> computable) {
        this.computable = computable;
    }
    public V compute(final A arg) throws InterruptedException {
        while (true) {
            Future<V> future = cache.get(arg);
            if (future == null) {
                Callable<V> eval = new Callable<V>() {
                    public V call() throws InterruptedException {
                        return computable.compute(arg);
                    }
                };
                FutureTask<V> futureTask = new FutureTask<V>(eval);
                future = cache.putIfAbsent(arg, futureTask);
                if (future == null) {
                    future = futureTask;
                    futureTask.run();
                }
            }

            try {
                return future.get();
            } catch (CancellationException e) {
                cache.remove(arg, future);
            } catch (ExecutionException e) {
                throw Lang.launderThrowable(e.getCause());
            }
        }
    }
}


利用例（整数の素因数分解の結果を Memoizer でキャッシュする）。

public class MemoizerTest {

    private final Computable<Long, Long[]> computable = 
        new Computable<Long, Long[]>() {
            public Long[] compute(Long arg) throws InterruptedException {
                return factor(arg);
            }
        };
    private final Computable<Long, Long[]> cashe = new Memoizer<Long, Long[]>(computable);

    private static Long[] factor(Long n) {
        List<Long> list = new ArrayList<Long>();
        for (long i = 2; i <= n / i; i++) {
            while (n % i == 0) {
                list.add(i); 
                n = n / i;
            }
        }
        if (n > 1) list.add(n);
        return list.toArray(new Long[list.size()]);
    }

    @Test
    public final void testCompute() throws Exception {
        System.out.println(Arrays.deepToString(cashe.compute(9L)));
        System.out.println(Arrays.deepToString(cashe.compute(81L)));
        System.out.println(Arrays.deepToString(cashe.compute(168L)));
    }
}


実行結果

[3, 3]
[3, 3, 3, 3]
[2, 2, 2, 3, 7]