itertools in Groovy

Generator を手に入れたので、再びハミングの問題に挑戦する。

同じように実装するためには itertools が必要だ。
Python documentation に Python で書かれた仕様が載っているのでそのまま実装する。

Changes

実装する際に Generator の方も変更した。

  • 名前を gen にした
  • worker thread で NoSuchElementException が発生した場合は StopIteration とみなすことにした
  • worker thread の例外を拾って main thread で再スローすることにした(izipLongest で必要)

Itertools

/**
 * Groovy port of the pure-Python reference implementations of itertools.
 *
 * Every lazy sequence is built with {@link #gen}, which runs a closure on a
 * worker thread and hands one value at a time to the consumer. Inside a
 * generator closure a NoSuchElementException means "stop iteration", so the
 * methods below simply let an exhausted input's next() end the generator.
 */
class Itertools {
  /**
   * Turns a closure into a lazily evaluated Iterable.
   *
   * The closure receives a single argument, the yield function; each call to
   * it hands one value to the consumer and blocks until the next value is
   * requested. Every call to iterator() starts a new daemon worker thread.
   *
   * Inside the worker, a NoSuchElementException ends the sequence, an
   * InterruptedException ends it silently, and any other Throwable is stored
   * and rethrown from hasNext() on the consuming thread.
   *
   * @param generator closure called as generator(yield)
   * @return an Iterable whose iterators pull values from the closure
   */
  static def gen(Closure generator) {
    return new Iterable() {
      @Override Iterator iterator() {
        def queue   = new LinkedList()  // size == 0 or 1
        // getLock: released by the worker when a value (or the end) is available.
        def getLock = new java.util.concurrent.Semaphore(0)
        // putLock: released by the consumer when it wants the next value.
        def putLock = new java.util.concurrent.Semaphore(0)
        def done    = false
        def error   = null
        def yield = {
          queue.addLast(it)
          getLock.release()
          putLock.acquire()
        }
        def worker = new Thread({
          try {
            // Do not run the generator until the first value is requested.
            putLock.acquire()
            generator(yield)
          } catch (NoSuchElementException stop) {
            // Treated as the end of the sequence (StopIteration).
          } catch (InterruptedException ignore) {
            // The consumer went away; just finish.
          } catch (Throwable t) {
            // Keep it so hasNext() can rethrow it on the consuming thread.
            error = t
          } finally {
            done = true
            getLock.release()
          }
        })

        worker.daemon = true
        worker.start()
        return new Iterator() {
          // Stops a worker that is still parked in yield once this iterator is collected.
          @Override void finalize() {
            worker.interrupt()
          }

          @Override boolean hasNext() {
            if (!done && queue.empty) {
              try {
                // Wake the worker and wait for either a value or the end.
                putLock.release()
                getLock.acquire()
              } catch (InterruptedException e) {
                worker.interrupt()
                return false
              }
              if (error) throw error
            }
            return !done
          }

          @Override def next() {
            if (hasNext()) return queue.removeFirst()
            throw new NoSuchElementException()
          }

          @Override void remove() {
            throw new UnsupportedOperationException()
          }
        }
      }
    }
  }

  /**
   * Returns iterable.iterator().
   *
   * @param iterable any object that responds to iterator()
   */
  static def iter(iterable) {
    return iterable.iterator()
  }

  // heapq
  /**
   * Merges several inputs by repeatedly yielding the smallest current head
   * element; the output is fully sorted when every input is itself sorted.
   * The input iterators are obtained when merge is called, not when the
   * result is iterated.
   *
   * @param iterables inputs to merge
   * @return a generator of the merged elements
   */
  static def merge(Object... iterables) {
    def iterators = iterables.toList().collect{ iter(it) }
    return gen{ yield ->
      // Each entry is [head value, iterator it came from], ordered by head value.
      def pqueue = new PriorityQueue(iterators.size(), { a, b -> a[0] <=> b[0] } as Comparator)
      for (iterator in iterators) {
        if (iterator.hasNext()) pqueue.add([iterator.next(), iterator])
      }
      while (!pqueue.empty) {
        def (min, iterator) = pqueue.poll()
        yield(min)
        if (iterator.hasNext()) pqueue.add([iterator.next(), iterator])
      }
    }
  }

  // Infinite Iterators:
  /**
   * Yields start, start + step, start + 2 * step, ... without end.
   *
   * The running value lives outside the generator closure, so iterating the
   * returned Iterable again continues from the current value instead of
   * restarting at start.
   *
   * @param start first value (default 0)
   * @param step  increment (default 1)
   */
  static def count(start = 0, step = 1) {
    def n = start
    return gen{ yield ->
      while (true) {
        yield(n)
        n += step
      }
    }
  }

  /**
   * Yields every element of iterable, remembering them, then replays the
   * remembered elements forever. An empty input produces an empty sequence.
   *
   * @param iterable source elements
   */
  static def cycle(iterable) {
    def saved = []
    return gen{ yield ->
      for (element in iterable) {
        yield(element)
        saved << element
      }
      while (saved)
        for (element in saved)
          yield(element)
    }
  }

  /**
   * Yields object forever, or exactly times times when times is given.
   * A zero or negative times yields nothing.
   *
   * @param object value to repeat
   * @param times  number of repetitions, or null for an endless sequence
   */
  static def repeat(object, times = null) {
    return gen{ yield ->
      if (times == null)
        while (true)
          yield(object)
      else
        // A counting loop rather than 0..<times: that range runs backwards
        // for a negative times and would yield once.
        for (int i = 0; i < times; i++)
          yield(object)
    }
  }

  // Iterators terminating on the shortest input sequence:
  /**
   * Yields all elements of the first input, then all of the second, and so on.
   *
   * @param iterables inputs to concatenate
   */
  static def chain(Object... iterables) {
    return gen{ yield ->
      for (iter in iterables)
        for (element in iter)
          yield(element)
    }
  }

  /**
   * Keeps the elements of data whose corresponding selector is truthy.
   * Unlike the other methods this is evaluated eagerly and returns a List.
   *
   * @param data      elements to filter
   * @param selectors truth values paired with data by position
   * @return List of the selected elements
   */
  static def compress(data, selectors) {
    return izip(data, selectors).findAll{ d, s -> s }.collect{ d, s -> d }
  }

  /**
   * Skips elements while predicate is true, then yields the first element
   * for which it is false and every element after it.
   *
   * @param predicate test applied to the leading elements
   * @param iterable  source elements
   */
  static def dropwhile(Closure predicate, iterable) {
    iterable = iter(iterable)
    return gen{ yield ->
      for (x in iterable)
        if (!predicate(x)) {
          yield(x)
          break
        }
      // Both loops share one iterator, so this continues after the break above.
      for (x in iterable)
        yield(x)
    }
  }

  /**
   * Groups consecutive elements that have equal keys.
   *
   * Returns an Iterator (not an Iterable) of [key, group] pairs, where group
   * is a generator over the elements of that run. All groups read from the
   * same underlying iterator, so a group must be consumed before advancing
   * to the next pair; a group that is iterated after the outer iterator has
   * moved on yields nothing unless the current run has the same key.
   *
   * @param iterable source elements
   * @param key      computes the grouping key of an element (default: the element itself)
   */
  static def groupby(iterable, Closure key = Closure.IDENTITY) {
    def iterator = iter(iterable)
    // A fresh Object serves as a "nothing read yet" marker for all three variables.
    def target   = new Object()
    def current  = target
    def value    = target
    def queue    = [] as Queue
    def done     = false
    return new Iterator() {
      @Override boolean hasNext() {
        if (!done && queue.empty) {
          try {
            // Skip whatever is left of the previous group.
            while (current == target) {
              value   = iterator.next()
              current = key(value)
            }
            // Pin this group's key in a local so the group keeps comparing
            // against its own key even after target moves to a later group.
            def groupKey = current
            target = groupKey
            def group  = gen{ yield ->
              while (current == groupKey) {
                yield(value)
                value   = iterator.next()
                current = key(value)
              }
            }
            queue << [groupKey, group]
          } catch (NoSuchElementException stop) {
            done = true
          }
        }
        return !queue.empty
      }

      @Override def next() {
        if (hasNext()) return queue.remove()
        throw new NoSuchElementException()
      }

      @Override void remove() {
        throw new UnsupportedOperationException()
      }
    }
  }

  /**
   * Yields the elements for which predicate is truthy.
   *
   * @param predicate test (default: the element's own truth value)
   * @param iterable  source elements
   */
  static def ifilter(predicate = { it as boolean }, iterable) {
    return gen{ yield ->
      for (x in iterable)
        if (predicate(x))
          yield(x)
    }
  }

  /**
   * Yields the elements for which predicate is falsy.
   *
   * @param predicate test (default: the element's own truth value)
   * @param iterable  source elements
   */
  static def ifilterfalse(predicate = { it as boolean }, iterable) {
    return gen{ yield ->
      for (x in iterable)
        if (!predicate(x))
          yield(x)
    }
  }

  /**
   * Yields the elements of iterable at positions start, start + step, ...
   * that are below stop.
   *
   * Called as islice(iterable, stop) or islice(iterable, start, stop[, step]).
   * A null start means 0, a null stop means Integer.MAX_VALUE and a null
   * step means 1. An empty selection (for example stop == 0 or
   * start >= stop) yields nothing.
   *
   * @param iterable source elements
   * @param args     stop, or start, stop and optionally step
   */
  static def islice(iterable, Object... args) {
    def (start, stop, step) = args.size() == 1 ? [null, args[0], null] : [*args]
    start = start != null ? start : 0
    stop  = stop  != null ? stop  : Integer.MAX_VALUE
    step  = step  != null ? step  : 1
    // Generator of the positions to keep.
    def iterator = iter(gen{ yield ->
      // '<' rather than '!=': a stop that is not reachable from start in
      // whole steps must still end the loop.
      for (int i = start; i < stop; i += step)
        yield(i)
    })
    return gen{ yield ->
      // Fetched inside the generator so that an empty selection ends the
      // sequence instead of throwing from islice itself.
      int nexti = iterator.next()
      iterable.eachWithIndex{ element, i ->
        if (i == nexti) {
            yield(element)
            // Throws NoSuchElementException after the last position, which ends this generator.
            nexti = iterator.next()
        }
      }
    }
  }

  /**
   * Yields function applied to the n-th elements of all inputs, or the list
   * of those elements when function is null. Stops when the shortest input
   * is exhausted.
   *
   * @param function  closure taking one argument per input, or null
   * @param iterables inputs read in parallel
   */
  static def imap(Closure function = null, Object... iterables) {
    iterables = iterables.collect{ iter(it) }
    return gen{ yield ->
      while (true) {
        def args = iterables.collect{ it.next() }
        if (function == null)
          yield(args)
        else
          yield(function(*args))
      }
    }
  }

  /**
   * Yields function(*args) for each argument list in iterable.
   *
   * @param function closure to call
   * @param iterable sequence of argument lists
   */
  static def starmap(Closure function, iterable) {
    return gen{ yield ->
      for (args in iterable)
        yield(function(*args))
    }
  }

  /**
   * Splits one input into n generators that each yield every element.
   * Values fetched for one generator are buffered for the others.
   *
   * @param iterable source elements
   * @param n        number of generators to return (default 2)
   * @return List of n generators
   */
  static def tee(iterable, int n = 2) {
    def iterator = iter(iterable)
    def queues   = (0..<n).collect{ [] as Queue }
    return queues.collect{ myqueue ->
      return gen{ yield ->
        while (true) {
          if (myqueue.empty) {
            // Fetch one new value and hand it to every queue, including this one.
            def newval = iterator.next()
            for (d in queues)
              d << newval
          }
          yield(myqueue.remove())
        }
      }
    }
  }

  /**
   * Yields elements as long as predicate is true and stops at the first
   * element for which it is false.
   *
   * @param predicate test applied to each element
   * @param iterable  source elements
   */
  static def takewhile(Closure predicate, iterable) {
    return gen{ yield ->
      for (x in iterable)
        if (predicate(x))
          yield(x)
        else
          break
    }
  }

  /**
   * Yields lists holding the n-th element of every input and stops when the
   * shortest input is exhausted. With no inputs the result is empty.
   *
   * @param iterables inputs read in parallel
   */
  static def izip(Object... iterables) {
    def iterators = iterables.toList().collect{ iter(it) }
    return gen{ yield ->
      while (iterators)
        yield(iterators.collect{ it.next() })
    }
  }

  /** Thrown inside izipLongest when the last remaining input runs out. */
  static class ZipExhausted extends RuntimeException {}

  /**
   * Like izip, but continues until the longest input is exhausted, filling
   * the gaps left by shorter inputs with fillvalue.
   *
   * @param keywords named arguments; 'fillvalue' is the filler (null when absent)
   * @param args     inputs read in parallel
   */
  static def izipLongest(Map keywords = [:], Object... args) {
    def fillvalue = keywords.get('fillvalue')
    // Number of inputs that may still run out before the zip has to stop.
    def counter   = args.size() - 1
    // Chained after every input: it runs when that input is exhausted and
    // throws once no other input is left.
    def sentinel = {
      return gen{ yield ->
        if (!counter) throw new ZipExhausted()
        counter -= 1
        yield(fillvalue)
      }
    }
    def filters   = repeat(fillvalue)
    def iterators = args.collect{ iter(chain(it, sentinel(), filters)) }
    return gen{ yield ->
      try {
        while (iterators)
          yield(iterators.collect{ it.next() })
      } catch(ZipExhausted ignore) {}
    }
  }

  // Combinatoric generators:
  /**
   * Cartesian product of the inputs, with the last input varying fastest.
   * The whole product is computed when product is called; the returned
   * generator only replays it.
   *
   * @param keywords named arguments; 'repeat' repeats the inputs that many times (default 1)
   * @param args     input pools
   */
  static def product(Map keywords = [:], Object... args) {
    def pools  = args.toList().collect{ it as List } * keywords.get('repeat', 1)
    def result = [[]]
    // Extend every partial tuple by each element of the next pool.
    for (pool in pools)
      result = result.collectMany{ x -> pool.collect{ y -> x + [y] } }
    return gen{ yield ->
      for (prod in result)
        yield(prod)
    }
  }

  /**
   * Yields the r-length permutations of the input elements.
   *
   * @param iterable source elements
   * @param r        permutation length; null means the number of elements
   */
  static def permutations(iterable, r = null) {
    def pool = iter(iterable).collect()
    def n    = pool.size()
    r = r == null ? n : r
    return gen{ yield ->
      for (indices in product(0..<n, repeat:r))
        // Keep only index tuples without repeats; unique(false) leaves indices untouched.
        if (indices.unique(false).size() == r)
          yield(pool[indices])
    }
  }

  /**
   * Yields the r-length combinations of the input elements.
   *
   * @param iterable source elements
   * @param r        combination length; null is passed on to permutations
   */
  static def combinations(iterable, r = null) {
    def pool = iter(iterable).collect()
    def n    = pool.size()
    return gen{ yield ->
      for (indices in permutations(0..<n, r))
        // A permutation of indices in ascending order is a combination.
        if (indices.sort(false) == indices)
          yield(pool[indices])
    }
  }

  /**
   * Yields the r-length combinations of the input elements in which an
   * element may appear more than once.
   *
   * @param iterable source elements
   * @param r        combination length
   */
  static def combinationsWithReplacement(iterable, r) {
    def pool = iter(iterable).collect()
    def n    = pool.size()
    return gen{ yield ->
      for (indices in product(0..<n, repeat:r))
        // Non-decreasing index tuples are combinations with replacement.
        if (indices.sort(false) == indices)
          yield(pool[indices])
    }
  }
}

Python の itertools にはないメソッドとして、iter と merge(heapq 由来)も用意した。

Test

import static Itertools.*

// Each assertion pins the output of one Itertools method on a small input.
assert iter(merge([1,3,5,7],[2,4],[6])).collect() == [1,2,3,4,5,6,7]

// Infinite Iterators:
// These never end, so only a prefix is taken from the iterator.
assert iter(count(10)).take(5).collect() == [10,11,12,13,14]
assert iter(cycle("ABCD")).take(12).collect() == ['A','B','C','D','A','B','C','D','A','B','C','D']
assert repeat(10,3).collect() == [10,10,10]

// Iterators terminating on the shortest input sequence:
assert chain('ABC', 'DEF').collect() == ['A','B','C','D','E','F']
// compress returns a List directly, so no collect() is needed.
assert compress("ABCDEF", [1,0,1,0,1,1]) == ['A','C','E','F']
assert dropwhile({ it < 5 }, [1,4,6,4,1]).collect() == [6,4,1]
// groupby yields [key, group] pairs; the first check reads only the keys, the second consumes each group.
assert groupby("AAAABBBCCDAABBB").collect{ k, g -> k } == ["A","B","C","D","A","B"]
assert groupby("AAAABBBCCD").collect{ k, g -> g.collect().join() } == ["AAAA","BBB","CC","D"]
assert ifilter({ it % 2 }, 0..<10).collect() == [1,3,5,7,9]
assert ifilterfalse({ it % 2 }, 0..<10).collect() == [0,2,4,6,8]
// islice(iterable, stop) and islice(iterable, start, stop[, step]); null means "no limit" / default.
assert islice("ABCDEFG", 2).collect() == ['A','B']
assert islice("ABCDEFG", 2, 4).collect() == ['C','D']
assert islice("ABCDEFG", 2, null).collect() == ['C','D','E','F','G']
assert islice("ABCDEFG", 0, null, 2).collect() == ['A','C','E','G']
assert imap(Math.&pow, [2,3,10], [5,2,3]).collect() == [32,9,1000]
assert starmap(Math.&pow, [[2,5], [3,2], [10,3]]).collect() == [32,9,1000]
assert takewhile({ it < 5 }, [1,4,6,4,1]).collect() == [1,4]
assert izip("ABCD", "xy").collect()*.join() == ["Ax","By"]
// fillvalue is passed as a named argument.
assert izipLongest("ABCD", "xy", fillvalue:'-').collect()*.join() == ["Ax","By","C-","D-"]

// Combinatoric generators:
assert product("ABCD", "xy").collect()*.join() == ["Ax","Ay","Bx","By","Cx","Cy","Dx","Dy"]
// repeat is passed as a named argument.
assert product(0..<2, repeat:3).collect()*.join() == ["000", "001", "010", "011", "100", "101", "110", "111"]
assert permutations("ABCD", 2).collect()*.join() == ["AB","AC","AD","BA","BC","BD","CA","CB","CD","DA","DB","DC"]
assert permutations(0..<3).collect()*.join() == ["012","021","102","120","201","210"]
assert combinations("ABCD", 2).collect()*.join() == ["AB","AC","AD","BC","BD","CD"]
assert combinations(0..<4, 3).collect()*.join() == ["012","013","023","123"]
assert combinationsWithReplacement("ABC", 2).collect()*.join() == ["AA","AB","AC","BB","BC","CC"]

named argument は product で気づいた。
それまでは List で渡していたので Python 風になって満足。

Cyclical Iterators

import static Itertools.*

// Builds the sequence 1, 2, 3, 4, 5, 6, 8, ... by feeding the output
// sequence back into itself: it is merged from 2x, 3x and 5x multiples of
// its own earlier elements.
def raymonds_hamming() {
  // Declared first and assigned at the end; the closure below reads it only
  // when it is actually iterated, which is after the assignment.
  def output
  def deferred_output = gen{ yield ->
    for (i in output) yield(i)
  }
  // One copy is returned to the caller, the other three feed the multipliers.
  def (result, p2, p3, p5) = tee(deferred_output, 4)
  def m2 = gen{ for (x in p2) it(2*x) }
  def m3 = gen{ for (x in p3) it(3*x) }
  def m5 = gen{ for (x in p5) it(5*x) }
  def merged = merge(m2, m3, m5)
  // 1G is a BigInteger literal; it seeds the sequence.
  def combined = chain([1G], merged)
  // Taking only the groupby keys collapses runs of equal values into one.
  output = gen{ groupby(combined).each{ k, g -> it(k) } }
  return result
}

// First 20 values of the sequence.
assert iter(raymonds_hamming()).take(20).collect() == [1,2,3,4,5,6,8,9,10,12,15,16,18,20,24,25,27,30,32,36]
// Element at index 1690, read once through islice and once by indexing the iterator.
assert iter(islice(raymonds_hamming(), 1690, 1691)).next() == 2125764000
assert iter(raymonds_hamming())[1690] == 2125764000
// Element at index 999999; left commented out in the original.
// assert iter(islice(raymonds_hamming(), 999999, 1000000)).next() == 519312780448388736089589843750000000000000000000000000000000000000000000000000000000

まわった!
ただ、Python に比べると圧倒的に遅い。
そもそも、仕様を説明するために Python で書かれた参考実装を移植したものと、最適化された本物の itertools 実装とでは、速度に違いが出て当然でもある。
ただ、以前書いた無限リスト版に比べると速いし1000000番目以外は気にならない。

終わりに

TODO をいくつか

  • 遅い原因の一つはおそらく Thread でそれ自体は仕方が無いことだが、ThreadPool を使用すれば多少改善されると思われる
  • Iterable で返しているので2度 iterator を呼び出すことができるがそのときの振る舞いが正しいのかわからない
// Iterating the same Iterable twice: the second pass does not restart at 10.
def c = count(10)
assert c in Iterable
assert takewhile({ it < 15 }, c).collect() == [10,11,12,13,14]
assert takewhile({ it < 15 }, c).collect() == []  // or [10,11,12,13,14] ?

count の場合 Python でもこのように動作したが、他は確認していない。

  • generator の delegate を利用すれば yield を引数で渡す必要がなくなるが、delegate を塞ぐほどのことではないと判断してやめた。


とりあえず、Python と同じように記述できるようになっただけでうれしい。

2012-04-18 更新

permutations の引数 r が 0 のときの動作が Python と違ったので修正した。

for (x in permutations([1,2,3], 0)) println x
// Before the fix
// [1, 2, 3]
// [1, 3, 2]
// [2, 1, 3]
// [2, 3, 1]
// [3, 1, 2]
// [3, 2, 1]

// After the fix
// []