Python(特にNumPy)の配列操作

紹介

NumPyの配列操作の仕組みについて紹介します。生のPythonよりもできる操作が多いのでこちらで統一します。 さて、NumPyで

a = [1,...,n]
b = a[::-1]

としたとき、2行目の操作は O(1) O(n)かご存知ですか?まあこれは O(1)です。他にもNumPyは配列に関する多くの操作が O(1)でできるよう配列のデータの管理が工夫されています。この記事ではいくつか具体例をあげて説明していきます。

知っている人向けに書くと、shapeとstridesの関係を書くだけです。ここらへんのattributesは開発者なら毎日のように使っていそうだけど一般ユーザはあんまり知らなそう(特にstrides)

shapeとstrides

多くのプログラミング言語で配列のデータは連続領域に確保されます。すなわち、a[k]a[k+1]は隣り合ったメモリ領域に確保されます。しかし、NumPyの配列は必ずしも個々のデータを連続領域にもちません。NumPyの配列は、data, shape, stridesという3要素によって管理されます。

data

データの先頭のポインタです。

array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> a.data
<memory at 0x110417948>

shape

データの多次元サイズ情報です

>>> a.shape
(10,)

strides

データアクセスのindexが1増えるとメモリ領域が何byteずれるかを表します。

>>> a.strides
(8,)

これは、配列anumpy.int64の型なので、a[k]a[k+1]が8byteずれていることを表します。numpy.int32型に変換してみると、stridesが変わることが確認できます。

>>> a.astype(np.int32).strides
(4,)

多次元配列の場合の例も見てみましょう。

>>> a
array([[0, 1, 2],
       [3, 4, 5]])
>>> a.shape
(2, 3)
>>> a.strides
(24, 8)

これは、この場合データがメモリ領域に連続して0,1,2,3,4,5と並んでおり、二次元配列の[i,j]番目の要素が連続領域の3*i + j番目の要素に対応するからです。このときjを一つずらすと、int64 1つ分の8byteがずれ、iを一つずらすとint64 3つ分の24byteがずれます。

要素のアクセス方法

例として3次元配列で説明します。a[i][j][k]にアクセスするためには、メモリのどこを見ればいいでしょうか?

先程の例と同じことですが、C言語風に答えを書くと*(a + strides[0] * i + strides[1] * j + strides[2] * k)です。

a[::-1]の表現方法

さて、ここまで説明をして、stridesいらなくないかと感じた人が多いと思います。実際、データが連続している場合、 \mathrm{strides}[k] = データサイズ * \prod_{i = k + 1}^{n} \mathrm{shape}[i]が成り立つからです。

自分でa = [1,2,3]にように配列を宣言する場合や、np.arangeなどによって宣言する場合についてこれは正しいですが、適宜shapestridesを変えることで同じデータをもとに幅広い配列を表現できます。いくつか例を紹介していきます。

たとえばa[::-1]のstridesを見てみましょう。

>>> a = np.arange(10)
>>> a.strides
(8,)
>>> a[::-1].strides
(-8,)

a[::-1]はメモリ上のデータ配置を変えずにstridesのみを変えることによってメモリ上のデータを後ろ向きに走査しています。データは変わらないので配列の宣言時に新しいメモリの確保は行われず O(1)です。

broadcastの表現方法

NumPyの嬉しさの一つとしてブロードキャストがあります。例えば以下のような演算ができます。

>>> a = np.arange(6).reshape(2,3)
>>> a
array([[0, 1, 2],
       [3, 4, 5]])
>>> b = np.arange(3)
>>> b
array([0, 1, 2])
>>> a + b
array([[0, 2, 4],
       [3, 5, 7]])

このとき、内部的に一度baと同じサイズに拡張して計算しています(np.broadcast_to)。

>>> a
array([[0, 1, 2],
       [3, 4, 5]])
>>> b
array([0, 1, 2])
>>> c = np.broadcast_to(b, (2,3))
>>> c
array([[0, 1, 2],
       [0, 1, 2]])
>>> a + c
array([[0, 2, 4],
       [3, 5, 7]])

この拡張もstridesを変更するだけで済むので O(1)です。

>>> b.strides
(8,)
>>> c.strides
(0, 8)

言われれば当たり前ですが、対応するstridesを0にすると、そのindexを変えてもアクセス位置が変わらないということであり、実質次元を拡張するような操作になっています。

transposeの表現方法

配列のtranspose(軸の入れ替え)も実は非常に簡単に表現できます。

>>> a = np.arange(6).reshape(2,3)
>>> a
array([[0, 1, 2],
       [3, 4, 5]])
>>> a.strides
(24, 8)
>>> a.T
array([[0, 3],
       [1, 4],
       [2, 5]])
>>> a.T.strides
(8, 24)

このように、strides(とshape)の要素の順序を入れ替えるだけです。メモリ領域には触らないので O(1)です。参照渡しは最高ですね〜〜〜

もうちょっと詳しく

普通に読み飛ばしてもらって結構です。

transposeというと2次元配列の行と列を入れ替えるイメージが強いと思いますが、numpyでのtransposeはもっと幅広く、一般次元の配列(テンソル)について、次元の順序を任意に入れ替える操作のことを指します。

配列の中身は省略しますが

>>> a = np.arange(12).reshape(2,2,3)
>>> b = np.transpose(a, (1,0,2))

のように扱えます。np.transposeの第2引数には(0, 1, ..., len(a) - 1)のpermutationが入ります。

これはなかなか便利で、たとえばnp.sumのような演算(特定の軸について潰すような演算)の実装は、配列の初めのk軸を潰すような演算として実装されていますが、np.sum(a, axis=1)のような演算が来たとしても、初めに軸を回転させて1番目のaxisを0番目に持ってくることにより、そうした実装を適用することが出来ます。transposeはこうしたリダクション演算の実装には欠かせない機能になっています。

indexing

先程のセクションに比べてちょっとレベルが下がります。

たとえば要素を2つおきにとってきた配列を作るとき、a[::2]としますよね。これもstridesを変えるだけで行えるので O(1)で生成できます。

>>> a = np.arange(10)
>>> a[::2].strides
(16,)

応用編

とりあえずここまででも、計算に便利な配列を O(1)で生成できるうれしさはまあまあ伝わっていると思いますが、一つ本質的に役に立つ例を紹介します。

多次元配列のindexアクセスではa[i,j,k]要素にアクセスするのにsum(a.strides * numpy.array([i, j, k]))のoffsetだけデータのポインタからずれた値を見る、と言いましたが、配列のアクセスのたびにこのような演算をしていてはオーバーヘッドが非常に大きくなってしまいます。そこで、配列の次元圧縮というものを考えます。

そのために、データがcontiguousかどうか、というフラグを考えます。これはnp.ndarray.data.contiguousに対応します。これは、データがメモリ上で連続しているか、つまりa[i][j][k](i,j,k)の辞書順に見た時、すべての要素が等間隔で並んでいるか、のフラグを指します。

>>> a = np.arange(10).reshape(2,5)
>>> a
array([[0, 1, 2, 3, 4],
       [5, 6, 7, 8, 9]])
>>> a.data.contiguous
True
>>> b = a[:,::2]
>>> b
array([[0, 2, 4],
       [5, 7, 9]])
>>> b.strides
(40, 16)
>>> b.data.contiguous
False

例えば上のように生成した配列bは(0,2)番目と(1,0)番目の要素が8byteしか離れていないため(ほかは16byteだけど)contiguousではありません。

さて、配列の次元圧縮とは、配列がcontiguousの場合に1次元の配列に潰すことです(より広義に、配列の一部の次元がcontiguousである場合にその次元を潰す、ということもあります)。

1次元の配列のindexアクセスは高速なので、実質1次元配列の多次元配列に関しては1次元配列のようにアクセスしてしまおうということです。

とくに要素ごとの足し算(np.ndarrayに対するa + b)をイメージしてもらえると、a, bがcontiguousな場合に計算が楽になることがわかると思います。

応用編(hoge)

実はnp.ndarray.stridesはreadonlyではないので自分で勝手に書き換えて遊ぶことができます。

>>> a
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> a.strides
(8,)
>>> a.strides = (4,)
>>> a
array([          0,  4294967296,           1,  8589934592,           2,
       12884901888,           3, 17179869184,           4, 21474836480])

とくに応用することはなさそうですけど遊んでみると楽しいです。メモリを生で触っている感覚を楽しめるかもしれないです。