ほくそ笑む

R言語と統計解析について

R でテンソルのインデックス行列のイテレータを書いた

R でテンソルのインデックス行列のイテレータを書く必要があったので書きました。

例えば、2 x 3 の行列のインデックス行列は expand.grid() 関数を使えば次のように簡単に得られます。

> expand.grid(1:2, 1:3)
  Var1 Var2
1    1    1
2    2    1
3    1    2
4    2    2
5    1    3
6    2    3

しかし、2000 x 2000 x 2000 のテンソルのインデックス行列を同じ方法で得ようとすると、メモリが足りずにエラーが出ます。

> expand.grid(1:2000, 1:2000, 1:2000)
エラー:  サイズ 29.8 Gb のベクトルを割り当てることができません

そこで、メモリに載るサイズに分割してインデックス行列を得ることを考えます。 これを実現するイテレータを今回実装しました(関数定義は記事の最後に載せます)。

> iter <- iter_index_matrix(dims = c(2000, 2000, 2000), chunk_size = 1000)
> index_matrix <- iter$nextElem()
> head(index_matrix)
  Var1 Var2 Var3
1    1    1    1
2    2    1    1
3    3    1    1
4    4    1    1
5    5    1    1
6    6    1    1
> nrow(index_matrix)
[1] 1000  # チャンクサイズ分のインデックス行列が得られた

イテレータをチャンクサイズを指定して作成することで、nextElem() メソッドを呼び出すごとにチャンクサイズ分だけのインデックス行列が得られます。

もう一度 nextElem() を呼び出すと次のチャンクサイズ分のインデックス行列が得られます。

> head(iter$nextElem())
  Var1 Var2 Var3
1 1001    1    1
2 1002    1    1
3 1003    1    1
4 1004    1    1
5 1005    1    1
6 1006    1    1
> head(iter$nextElem())
  Var1 Var2 Var3
1    1    2    1
2    2    2    1
3    3    2    1
4    4    2    1
5    5    2    1
6    6    2    1

次のチャンクを持つかどうかを判定するための hasNext() メソッドも持っています。

> iter$hasNext()
[1] TRUE

使い方としては、while ループの中でインデックス行列を処理することを想定しています。

iter <- iter_index_matrix(dims, chunk_size)
while (iter$hasNext()) {
  index_matrix <- iter$nextElem()
  # なんらかの処理
}

最後に、iter_index_matrix() の定義を載せておきます。

iter_index_matrix <- function(dims, chunk_size = 1L) {
  f_split <- cumprod(dims) <= chunk_size
  has_next <- TRUE
  if (all(f_split)) {
    next_element <- function() {
      if (!has_next) stop("StopIteration")
      has_next <<- FALSE
      expand.grid(lapply(dims, seq_len))
    }
  } else {
    split_dims <- split(dims, f_split)
    in_block_dims <- split_dims[["TRUE"]]
    iter_dim <- split_dims[["FALSE"]][1L]
    max_out_block_dims <- split_dims[["FALSE"]][-1L]
    out_block_dims <- rep(1L, length(max_out_block_dims))
    increment_out_block_dims <- function() {
      increment <- function(dims, max_dims) {
        if (length(dims) == 0L) {
          has_next <<- FALSE
          dims <- vector("numeric")
        } else if (dims[1L] < max_dims[1L]) {
          dims[1] <- dims[1] + 1L
        } else {
          dims <- c(1L, increment(dims[-1L], max_dims[-1L]))
        }
        dims
      }
      out_block_dims <<- increment(out_block_dims, max_out_block_dims)
      invisible(NULL)
    }
    step_size <- floor(chunk_size / prod(in_block_dims))
    inds <- c(seq.int(1L, iter_dim, by = step_size), iter_dim + 1L)
    in_block_dims <- lapply(in_block_dims, seq_len)
    i <- 2L
    next_element <- function() {
      if (!has_next) stop("StopIteration")
      start <- inds[i - 1L]
      end <- inds[i] - 1L
      args <- c(in_block_dims, list(start:end), out_block_dims)
      i <<- i + 1L
      if (i > length(inds)) {
        i <<- 2L
        increment_out_block_dims()
      }
      expand.grid(args)
    }
  }
  obj <- list(nextElem = next_element, hasNext = function() has_next)
  class(obj) <- c("ihasNext", "abstractiter", "iter")
  obj
}

このイテレータは iterators パッケージの nextElem() 関数に対応しています。

library(iterators)
index_matrix <- nextElem(iter)

また、itertools パッケージの hasNext() 関数にも対応しています。

library(itertools)
hasNext(iter)

大規模なテンソルのインデックス行列を生成したいときはぜひご活用ください。

Enjoy!