itertools in Groovy
Generator を手に入れたので、再びハミングの問題に挑戦する。
同じように実装するためには itertools が必要だ。
Python documentation に Python で書かれた仕様が載っているのでそのまま実装する。
Changes
実装する際に Generator の方も変更した。
- 名前を gen にした
- worker thread で NoSuchElementException が発生した場合は StopIteration とみなすことにした
- worker thread の例外を拾って main thread で再スローすることにした(izipLongest で必要)
Itertools
/**
 * A Groovy port of the pure-Python reference implementations of
 * {@code itertools} (plus {@code iter} and {@code heapq.merge}).
 *
 * Generators are emulated by {@link #gen(Closure)}, which runs the generator
 * body on a worker thread and hands values over one at a time.
 */
class Itertools {

    /**
     * Wraps a generator closure into an Iterable.
     *
     * The closure receives a single argument, a {@code yield} closure; every
     * call to {@code yield(value)} produces one element. Each call to
     * {@code iterator()} starts a fresh daemon worker thread that runs the
     * closure. Two semaphores make the worker and the consumer take turns, so
     * the worker only advances while the consumer is waiting in hasNext().
     *
     * A NoSuchElementException raised in the worker is treated as the end of
     * the sequence; any other Throwable raised in the worker is rethrown in
     * the consuming thread by hasNext().
     *
     * @param generator closure taking the yield closure as its only argument
     * @return an Iterable whose iterators produce the yielded values
     */
    static def gen(Closure generator) {
        return new Iterable() {
            @Override
            Iterator iterator() {
                def queue = new LinkedList() // size == 0 or 1
                def getLock = new java.util.concurrent.Semaphore(0)
                def putLock = new java.util.concurrent.Semaphore(0)
                def done = false
                def error = null
                // Called on the worker thread: publish one value, wake the
                // consumer, then park until the consumer asks for more.
                def yield = {
                    queue.addLast(it)
                    getLock.release()
                    putLock.acquire()
                }
                def worker = new Thread({
                    try {
                        // Do not start generating until the first hasNext().
                        putLock.acquire()
                        generator(yield)
                    } catch (NoSuchElementException stop) {
                        // Plays the role of Python's StopIteration.
                    } catch (InterruptedException ignore) {
                        // The consumer went away; just finish.
                    } catch (Throwable t) {
                        error = t
                    } finally {
                        done = true
                        getLock.release()
                    }
                })
                worker.daemon = true
                worker.start()
                return new Iterator() {
                    @Override
                    void finalize() {
                        // Release a worker that is still parked in yield.
                        worker.interrupt()
                    }

                    @Override
                    boolean hasNext() {
                        if (!done && queue.empty) {
                            try {
                                putLock.release()
                                getLock.acquire()
                            } catch (InterruptedException e) {
                                worker.interrupt()
                                // Keep the interrupt visible to the caller
                                // instead of silently clearing it.
                                Thread.currentThread().interrupt()
                                return false
                            }
                            if (error) throw error
                        }
                        return !done
                    }

                    @Override
                    def next() {
                        if (hasNext()) return queue.removeFirst()
                        throw new NoSuchElementException()
                    }

                    @Override
                    void remove() {
                        throw new UnsupportedOperationException()
                    }
                }
            }
        }
    }

    /** Returns {@code iterable.iterator()}. */
    static def iter(iterable) {
        return iterable.iterator()
    }

    // heapq

    /**
     * Lazily merges the given inputs into one sequence, always yielding the
     * smallest head element next (elements are compared with {@code <=>}).
     * The result is sorted only if every input is already sorted.
     *
     * @param iterables zero or more inputs; with none the result is empty
     */
    static def merge(Object... iterables) {
        def iterators = iterables.toList().collect{ iter(it) }
        return gen{ yield ->
            // PriorityQueue rejects an initial capacity below 1, so clamp it
            // to keep merge() with no inputs working.
            def pqueue = new PriorityQueue(Math.max(1, iterators.size()),
                                           { a, b -> a[0] <=> b[0] } as Comparator)
            for (iterator in iterators) {
                if (iterator.hasNext()) pqueue.add([iterator.next(), iterator])
            }
            while (!pqueue.empty) {
                def (min, iterator) = pqueue.poll()
                yield(min)
                if (iterator.hasNext()) pqueue.add([iterator.next(), iterator])
            }
        }
    }

    // Infinite Iterators:

    /**
     * Yields start, start + step, start + 2 * step, ... without end.
     *
     * The running value is kept outside the generator body, so iterating the
     * returned Iterable a second time continues counting instead of
     * restarting from {@code start}.
     */
    static def count(start = 0, step = 1) {
        def n = start
        return gen{ yield ->
            while (true) {
                yield(n)
                n += step
            }
        }
    }

    /**
     * Yields the elements of {@code iterable} once, remembering them, and
     * then repeats the remembered elements forever. Yields nothing more when
     * the input is empty.
     */
    static def cycle(iterable) {
        return gen{ yield ->
            // Kept per iteration so separate iterators do not share a buffer.
            def saved = []
            for (element in iterable) {
                yield(element)
                saved << element
            }
            while (saved) for (element in saved) yield(element)
        }
    }

    /**
     * Yields {@code object} {@code times} times, or forever when
     * {@code times} is null.
     */
    static def repeat(object, times = null) {
        return gen{ yield ->
            if (times == null) while (true) yield(object)
            else for (i in 0..<times) yield(object)
        }
    }

    // Iterators terminating on the shortest input sequence:

    /** Yields every element of the first input, then of the second, and so on. */
    static def chain(Object... iterables) {
        return gen{ yield ->
            for (iterable in iterables) for (element in iterable) yield(element)
        }
    }

    /**
     * Returns the elements of {@code data} whose corresponding entry in
     * {@code selectors} is true under Groovy truth. Unlike the other
     * functions the result is built eagerly as a List.
     */
    static def compress(data, selectors) {
        return izip(data, selectors).findAll{ d, s -> s }.collect{ d, s -> d }
    }

    /**
     * Skips elements while {@code predicate} holds, then yields the first
     * element that fails it and every element after that.
     */
    static def dropwhile(Closure predicate, iterable) {
        iterable = iter(iterable)
        return gen{ yield ->
            for (x in iterable) if (!predicate(x)) {
                yield(x)
                break
            }
            // Continues on the same iterator, i.e. after the first yielded element.
            for (x in iterable) yield(x)
        }
    }

    /**
     * Groups consecutive elements that map to the same key.
     *
     * Returns an Iterator of {@code [key, group]} pairs, where {@code group}
     * is a generator over the run of elements sharing that key. The groups
     * read from the same underlying iterator as the outer Iterator, so
     * advancing the outer Iterator moves past the current group.
     *
     * @param iterable input sequence
     * @param key      closure computing the grouping key (identity by default)
     */
    static def groupby(iterable, Closure key = Closure.IDENTITY) {
        def iterator = iter(iterable)
        // A fresh Object serves as the "no key seen yet" marker.
        def target = new Object()
        def current = target
        def value = target
        def queue = [] as Queue
        def done = false
        return new Iterator() {
            @Override
            boolean hasNext() {
                if (!done && queue.empty) {
                    try {
                        // Skip whatever is left of the previous group.
                        while (current == target) {
                            value = iterator.next()
                            current = key(value)
                        }
                        target = current
                        def group = gen{ yield ->
                            while (current == target) {
                                yield(value)
                                value = iterator.next()
                                current = key(value)
                            }
                        }
                        queue << [current, group]
                    } catch (NoSuchElementException stop) {
                        done = true
                    }
                }
                return !queue.empty
            }

            @Override
            def next() {
                if (hasNext()) return queue.remove()
                throw new NoSuchElementException()
            }

            @Override
            void remove() {
                throw new UnsupportedOperationException()
            }
        }
    }

    /** Yields the elements for which {@code predicate} is true (default: Groovy truth). */
    static def ifilter(predicate = { it as boolean }, iterable) {
        return gen{ yield ->
            for (x in iterable) if (predicate(x)) yield(x)
        }
    }

    /** Yields the elements for which {@code predicate} is false (default: Groovy truth). */
    static def ifilterfalse(predicate = { it as boolean }, iterable) {
        return gen{ yield ->
            for (x in iterable) if (!predicate(x)) yield(x)
        }
    }

    /**
     * Yields the elements at positions start, start + step, ... that are
     * below stop.
     *
     * Call forms: {@code islice(it, stop)}, {@code islice(it, start, stop)}
     * and {@code islice(it, start, stop, step)}. A null start means 0, a null
     * stop means Integer.MAX_VALUE and a null step means 1. Consumption of
     * the input ends as soon as the next wanted position would reach stop.
     */
    static def islice(iterable, Object... args) {
        def (start, stop, step) = args.size() == 1 ? [null, args[0], null] : [*args]
        start = start != null ? start : 0
        stop = stop != null ? stop : Integer.MAX_VALUE
        step = step != null ? step : 1
        return gen{ yield ->
            // The cursor lives inside the generator, so an empty slice simply
            // yields nothing and every iteration starts from `start` again.
            // long avoids wrap-around when stop is Integer.MAX_VALUE.
            long nexti = start
            if (nexti >= stop) return
            long i = 0
            for (element in iterable) {
                if (i == nexti) {
                    yield(element)
                    nexti += step
                    // ">=" rather than "==": stop need not be hit exactly
                    // when step is greater than 1.
                    if (nexti >= stop) break
                }
                i++
            }
        }
    }

    /**
     * Yields {@code function} applied to one element taken from each input,
     * or the list of those elements when {@code function} is null. Ends when
     * the shortest input is exhausted.
     */
    static def imap(Closure function = null, Object... iterables) {
        iterables = iterables.collect{ iter(it) }
        return gen{ yield ->
            while (true) {
                // next() on an exhausted input throws NoSuchElementException,
                // which gen treats as the end of the sequence.
                def args = iterables.collect{ it.next() }
                if (function == null) yield(args)
                else yield(function(*args))
            }
        }
    }

    /** Yields {@code function(*args)} for every argument list in {@code iterable}. */
    static def starmap(Closure function, iterable) {
        return gen{ yield ->
            for (args in iterable) yield(function(*args))
        }
    }

    /**
     * Splits one input into {@code n} generators that each see all elements.
     * Every value pulled from the shared iterator is appended to the buffer
     * of every generator.
     *
     * @return a List of {@code n} generators
     */
    static def tee(iterable, int n = 2) {
        def iterator = iter(iterable)
        def queues = (0..<n).collect{ [] as Queue }
        return queues.collect{ myqueue ->
            return gen{ yield ->
                while (true) {
                    if (myqueue.empty) {
                        def newval = iterator.next()
                        for (d in queues) d << newval
                    }
                    yield(myqueue.remove())
                }
            }
        }
    }

    /** Yields elements until the first one for which {@code predicate} is false. */
    static def takewhile(Closure predicate, iterable) {
        return gen{ yield ->
            for (x in iterable)
                if (predicate(x)) yield(x)
                else break
        }
    }

    /**
     * Yields lists holding one element from each input; ends with the
     * shortest input. Yields nothing when called without inputs.
     */
    static def izip(Object... iterables) {
        def iterators = iterables.toList().collect{ iter(it) }
        return gen{ yield ->
            while (iterators) yield(iterators.collect{ it.next() })
        }
    }

    /** Raised internally by izipLongest once every input has run out. */
    static class ZipExhausted extends RuntimeException {}

    /**
     * Like izip, but continues until the longest input is exhausted; missing
     * values are replaced by the {@code fillvalue} named argument (null when
     * not given).
     */
    static def izipLongest(Map keywords = [:], Object... args) {
        def fillvalue = keywords.get('fillvalue')
        // Number of inputs that may still run dry before the zip must stop.
        def counter = args.size() - 1
        def sentinel = {
            return gen{ yield ->
                if (!counter) throw new ZipExhausted()
                counter -= 1
                yield(fillvalue)
            }
        }
        def fillers = repeat(fillvalue)
        def iterators = args.collect{ iter(chain(it, sentinel(), fillers)) }
        return gen{ yield ->
            try {
                while (iterators) yield(iterators.collect{ it.next() })
            } catch(ZipExhausted ignore) {}
        }
    }

    // Combinatoric generators:

    /**
     * Cartesian product of the inputs, with the rightmost input varying
     * fastest. The named argument {@code repeat} (default 1) repeats the
     * list of inputs that many times. The combinations are computed eagerly
     * and then yielded one by one.
     */
    static def product(Map keywords = [:], Object... args) {
        // Read the option without inserting a default into the caller's map.
        def times = keywords.containsKey('repeat') ? keywords['repeat'] : 1
        def pools = args.toList().collect{ it as List } * times
        def result = [[]]
        for (pool in pools) {
            def current = pool
            result = result.collectMany{ x -> current.collect{ y -> x + [y] } }
        }
        return gen{ yield ->
            for (prod in result) yield(prod)
        }
    }

    /**
     * Yields the length-{@code r} orderings of distinct positions of the
     * input; {@code r} defaults to the input size. An explicit {@code r} of
     * 0 yields a single empty list.
     */
    static def permutations(iterable, r = null) {
        def pool = iter(iterable).collect()
        def n = pool.size()
        // Test for null explicitly so that r == 0 is not replaced by n.
        r = r == null ? n : r
        return gen{ yield ->
            for (indices in product(0..<n, repeat:r))
                // unique(false) leaves the index list itself untouched.
                if (indices.unique(false).size() == r) yield(pool[indices])
        }
    }

    /**
     * Yields the permutations of {@code r} positions whose index lists are in
     * ascending order, i.e. each selection of positions once.
     */
    static def combinations(iterable, r = null) {
        def pool = iter(iterable).collect()
        def n = pool.size()
        return gen{ yield ->
            for (indices in permutations(0..<n, r))
                if (indices.sort(false) == indices) yield(pool[indices])
        }
    }

    /**
     * Like combinations, but a position may be chosen more than once: yields
     * every length-{@code r} index list that is in non-descending order.
     */
    static def combinationsWithReplacement(iterable, r) {
        def pool = iter(iterable).collect()
        def n = pool.size()
        return gen{ yield ->
            for (indices in product(0..<n, repeat:r))
                if (indices.sort(false) == indices) yield(pool[indices])
        }
    }
}
itertools 以外のメソッドで iter と merge を用意した。
Test
// Smoke tests for Itertools: each assertion drains a (possibly truncated)
// generator and compares the result with a literal expected value.
import static Itertools.*

// heapq-style merge of three ascending lists
assert iter(merge([1,3,5,7],[2,4],[6])).collect() == [1,2,3,4,5,6,7]

// Infinite Iterators:
// (take(n) bounds the otherwise endless sequences)
assert iter(count(10)).take(5).collect() == [10,11,12,13,14]
assert iter(cycle("ABCD")).take(12).collect() == ['A','B','C','D','A','B','C','D','A','B','C','D']
assert repeat(10,3).collect() == [10,10,10]

// Iterators terminating on the shortest input sequence:
assert chain('ABC', 'DEF').collect() == ['A','B','C','D','E','F']
// compress returns a List directly, so no collect() is needed
assert compress("ABCDEF", [1,0,1,0,1,1]) == ['A','C','E','F']
assert dropwhile({ it < 5 }, [1,4,6,4,1]).collect() == [6,4,1]
// groupby yields [key, group] pairs: first only the keys, then the group contents
assert groupby("AAAABBBCCDAABBB").collect{ k, g -> k } == ["A","B","C","D","A","B"]
assert groupby("AAAABBBCCD").collect{ k, g -> g.collect().join() } == ["AAAA","BBB","CC","D"]
assert ifilter({ it % 2 }, 0..<10).collect() == [1,3,5,7,9]
assert ifilterfalse({ it % 2 }, 0..<10).collect() == [0,2,4,6,8]
// islice call forms: (stop), (start, stop), (start, null), (start, null, step)
assert islice("ABCDEFG", 2).collect() == ['A','B']
assert islice("ABCDEFG", 2, 4).collect() == ['C','D']
assert islice("ABCDEFG", 2, null).collect() == ['C','D','E','F','G']
assert islice("ABCDEFG", 0, null, 2).collect() == ['A','C','E','G']
assert imap(Math.&pow, [2,3,10], [5,2,3]).collect() == [32,9,1000]
assert starmap(Math.&pow, [[2,5], [3,2], [10,3]]).collect() == [32,9,1000]
assert takewhile({ it < 5 }, [1,4,6,4,1]).collect() == [1,4]
assert izip("ABCD", "xy").collect()*.join() == ["Ax","By"]
// fillvalue is passed as a named argument
assert izipLongest("ABCD", "xy", fillvalue:'-').collect()*.join() == ["Ax","By","C-","D-"]

// Combinatoric generators:
assert product("ABCD", "xy").collect()*.join() == ["Ax","Ay","Bx","By","Cx","Cy","Dx","Dy"]
// repeat is passed as a named argument
assert product(0..<2, repeat:3).collect()*.join() == ["000", "001", "010", "011", "100", "101", "110", "111"]
assert permutations("ABCD", 2).collect()*.join() == ["AB","AC","AD","BA","BC","BD","CA","CB","CD","DA","DB","DC"]
assert permutations(0..<3).collect()*.join() == ["012","021","102","120","201","210"]
assert combinations("ABCD", 2).collect()*.join() == ["AB","AC","AD","BC","BD","CD"]
assert combinations(0..<4, 3).collect()*.join() == ["012","013","023","123"]
assert combinationsWithReplacement("ABC", 2).collect()*.join() == ["AA","AB","AC","BB","BC","CC"]
Groovy で named argument が使えることは、product を実装しているときに気づいた。
それまでは List で渡していたので Python 風になって満足。
Cyclical Iterators
import static Itertools.*

/**
 * Builds the Hamming sequence (1, 2, 3, 4, 5, 6, 8, ...) as a self-referential
 * stream: the output is fed back through tee, scaled by 2, 3 and 5, merged,
 * prefixed with 1 and de-duplicated with groupby.
 *
 * @return one of the four tee'd views of the output stream
 */
def raymonds_hamming() {
    // Assigned at the end of the function; the closure below only reads it
    // once iteration actually starts.
    def output
    def deferred_output = gen{ yield ->
        for (number in output) yield(number)
    }

    def branches = tee(deferred_output, 4)
    def result = branches[0]

    // Generator over `source` with every element multiplied by `factor`.
    def scaled = { factor, source ->
        gen{ yield ->
            for (x in source) yield(factor * x)
        }
    }
    def m2 = scaled(2, branches[1])
    def m3 = scaled(3, branches[2])
    def m5 = scaled(5, branches[3])

    def combined = chain([1G], merge(m2, m3, m5))

    // groupby collapses runs of equal values; only the keys are emitted.
    output = gen{ yield ->
        groupby(combined).each{ key, group -> yield(key) }
    }
    return result
}

assert iter(raymonds_hamming()).take(20).collect() == [1,2,3,4,5,6,8,9,10,12,15,16,18,20,24,25,27,30,32,36]
assert iter(islice(raymonds_hamming(), 1690, 1691)).next() == 2125764000
assert iter(raymonds_hamming())[1690] == 2125764000
// Left disabled in the original listing:
// assert iter(islice(raymonds_hamming(), 999999, 1000000)).next() == 519312780448388736089589843750000000000000000000000000000000000000000000000000000000
まわった!
ただ、Python に比べると圧倒的に遅い。
そもそも Python で書かれた仕様のための実装と、最適化された実装の itertools の違いもある。
ただ、以前書いた無限リスト版に比べると速いし1000000番目以外は気にならない。
終わりに
TODO をいくつか
- 遅い原因の一つはおそらく Thread でそれ自体は仕方が無いことだが、ThreadPool を使用すれば多少改善されると思われる
- Iterable で返しているので2度 iterator を呼び出すことができるがそのときの振る舞いが正しいのかわからない
// Iterates the same Iterable returned by count() twice to show what the
// second pass produces.
def c = count(10)
assert c in Iterable
assert takewhile({ it < 15 }, c).collect() == [10,11,12,13,14]
// The second pass over the same object yields nothing below 15.
assert takewhile({ it < 15 }, c).collect() == [] // or [10,11,12,13,14] ?
count の場合 Python でもこのように動作したが、他は確認していない。
とりあえず、Python と同じように記述できるようになっただけでうれしい。
2012-04-18 更新
permutations の第2引数 r が 0 のときの動作が Python と違っていたので修正した
// Prints every permutation of length 0 of a three-element list.
for (x in permutations([1,2,3], 0)) println x
// Before the fix
// [1, 2, 3]
// [1, 3, 2]
// [2, 1, 3]
// [2, 3, 1]
// [3, 1, 2]
// [3, 2, 1]
// After the fix
// []