Manjusaka

Manjusaka

なぜ時々 Python での乗算がビット演算よりも速いのか

私はもう水文を書くことはないと思っていましたが、突然、自分が今は生活を維持するために水文を書かなければならないことに気づきました。それなら、水文を書き続けましょう。

ある日、技術系のグループで兄貴がこんな質問をしました。なぜある場合において、Python の単純な乗算 / 除算がビット演算よりも遅いのか。

まずは実事求是の精神を持って、検証してみましょう。

In [33]: %timeit 1073741825*2                                                                                                                                                                                                                                                                           
7.47 ns ± 0.0843 ns per loop (mean ± std. dev. of 7 runs, 100000000 loops each)

In [34]: %timeit 1073741825<<1                                                                                                                                                                                                                                                                          
7.43 ns ± 0.0451 ns per loop (mean ± std. dev. of 7 runs, 100000000 loops each)

In [35]: %timeit 1073741823<<1                                                                                                                                                                                                                                                                          
7.48 ns ± 0.0621 ns per loop (mean ± std. dev. of 7 runs, 100000000 loops each)

In [37]: %timeit 1073741823*2                                                                                                                                                                                                                                                                           
7.47 ns ± 0.0564 ns per loop (mean ± std. dev. of 7 runs, 100000000 loops each)

いくつかの興味深い現象が見つかりました。

  1. x<=2^30 の場合、乗算は直接のビット演算よりも速い。
  2. x>2^32 の場合、乗算はビット演算よりも明らかに遅い。

この現象は興味深いです。この現象の root cause は何でしょうか?実際、これは Python の底層の実装に関係しています。

簡単に話す#

PyLongObject の実装#

Python 2.x の時代、Python では整数型を二つに分けていました。一つは long、もう一つは int です。Python3 ではこれらは統合され、現在は long のみが残っています。

まず、long というデータ構造の底層の実装を見てみましょう。

struct _longobject {
    PyObject_VAR_HEAD
    digit ob_digit[1];
};

ここでは、PyObject_VAR_HEAD の意味を気にする必要はありません。私たちが気にするのは ob_digit です。

ここで、ob_digit は C99 の「柔軟な配列」を使用して任意の長さの整数を格納するために実装されています。公式コードのドキュメントを見てみましょう。

長整数の表現。数の絶対値は SUM (for i=0 through abs (ob_size)-1) ob_digit [i] * 2**(SHIFT*i) に等しい。負の数は ob_size < 0 で表現され、ゼロは ob_size == 0 で表現される。正規化された数では、ob_digit [abs (ob_size)-1](最上位桁)は決してゼロではない。また、すべての有効な i に対して、0 <= ob_digit [i] <= MASK である。割り当て関数は、ob_digit [0] ... ob_digit [abs (ob_size)-1] が実際に利用可能であるように追加のメモリを割り当てる。注意: PyVarObject のサブタイプを操作する一般的なコードは、整数が ob_size の符号ビットを悪用していることを認識する必要がある。

要するに、Python は十進数を 2^(SHIFT) 進数に変換して格納しています。ここは少し理解しにくいかもしれません。例を挙げると、私のコンピュータでは SHIFT が 30 です。整数 1152921506754330628 があると仮定すると、これを 2^30 進数で表すと次のようになります: 4*(2^30)^0+2*(2^30)^1+1*(2^30)^2。この時、ob_digit は三つの要素を持つ配列で、その値は [4,2,1] です。

OK、こうした基礎知識を理解した後、Python における乗算の実装を見てみましょう。

Python における乗算#

Python の乗算は二つの部分に分かれています。そのうち大きな数の乗算について、Python は Karatsuba アルゴリズム1 を使用しています。具体的な実装は以下の通りです。

static PyLongObject *
k_mul(PyLongObject *a, PyLongObject *b)
{
    Py_ssize_t asize = Py_ABS(Py_SIZE(a));
    Py_ssize_t bsize = Py_ABS(Py_SIZE(b));
    PyLongObject *ah = NULL;
    PyLongObject *al = NULL;
    PyLongObject *bh = NULL;
    PyLongObject *bl = NULL;
    PyLongObject *ret = NULL;
    PyLongObject *t1, *t2, *t3;
    Py_ssize_t shift;           /* 分割する桁数 */
    Py_ssize_t i;

    /* (ah*X+al)(bh*X+bl) = ah*bh*X*X + (ah*bl + al*bh)*X + al*bl
     * k = (ah+al)*(bh+bl) = ah*bl + al*bh + ah*bh + al*bl
     * 元の積は
     *     ah*bh*X*X + (k - ah*bh - al*bl)*X + al*bl
     * X を 2 の累乗に選ぶことで、"*X" は単にシフトであり、
     * サイズが半分の数に対して 3 回の乗算に減らされます。
     */

    /* より大きな数に基づいて分割したい。b が最大になるように調整する。 */
    if (asize > bsize) {
        t1 = a;
        a = b;
        b = t1;

        i = asize;
        asize = bsize;
        bsize = i;
    }

    /* どちらかの数が小さすぎる場合は、通常の算数を使用します。 */
    i = a == b ? KARATSUBA_SQUARE_CUTOFF : KARATSUBA_CUTOFF;
    if (asize <= i) {
        if (asize == 0)
            return (PyLongObject *)PyLong_FromLong(0);
        else
            return x_mul(a, b);
    }

    /* a が b に比べて小さい場合、b で分割すると ah==0 という退化したケースになり、
     * Karatsuba は「グレードスクール」よりも(さらに)効率が悪くなる可能性があります。
     * しかし、b を「大きな桁」の文字列として見ることで、k_mul へのバランスの取れた呼び出しの
     * シーケンスを得ることができます。
     */
    if (2 * asize <= bsize)
        return k_lopsided_mul(a, b);

    /* a と b を hi と lo の部分に分割します。 */
    shift = bsize >> 1;
    if (kmul_split(a, shift, &ah, &al) < 0) goto fail;
    assert(Py_SIZE(ah) > 0);            /* 分割が退化していないことを確認 */

    if (a == b) {
        bh = ah;
        bl = al;
        Py_INCREF(bh);
        Py_INCREF(bl);
    }
    else if (kmul_split(b, shift, &bh, &bl) < 0) goto fail;

    /* 計画:
     * 1. 結果のスペースを割り当てる(asize + bsize 桁: それは常に十分です)。
     * 2. ah*bh を計算し、結果の 2*shift にコピーします。
     * 3. al*bl を計算し、結果の 0 にコピーします。これは #2 と重ならないことに注意してください。
     * 4. al*bl を結果から引きます。これはアンダーフローする可能性がありますが、気にしません。
     *     実際には BASE**(sizea + sizeb) で符号なしの算術を行っており、
     *     最終的な結果が収まる限り、高桁からの借用や繰り上がりは無視できます。
     * 5. ah*bh を結果から引きます。
     * 6. (ah+al)*(bh+bl) を計算し、結果に shift から追加します。
     */

    /* 1. 結果のスペースを割り当てます。 */
    ret = _PyLong_New(asize + bsize);
    if (ret == NULL) goto fail;
#ifdef Py_DEBUG
    /* ゴミで埋めて、初期化されていない桁への参照をキャッチします。 */
    memset(ret->ob_digit, 0xDF, Py_SIZE(ret) * sizeof(digit));
#endif

    /* 2. t1 <- ah*bh を計算し、高桁の結果にコピーします。 */
    if ((t1 = k_mul(ah, bh)) == NULL) goto fail;
    assert(Py_SIZE(t1) >= 0);
    assert(2*shift + Py_SIZE(t1) <= Py_SIZE(ret));
    memcpy(ret->ob_digit + 2*shift, t1->ob_digit,
           Py_SIZE(t1) * sizeof(digit));

    /* ah*bh コピーよりも高い桁をゼロにします。 */
    i = Py_SIZE(ret) - 2*shift - Py_SIZE(t1);
    if (i)
        memset(ret->ob_digit + 2*shift + Py_SIZE(t1), 0,
               i * sizeof(digit));

    /* 3. t2 <- al*bl を計算し、低桁にコピーします。 */
    if ((t2 = k_mul(al, bl)) == NULL) {
        Py_DECREF(t1);
        goto fail;
    }
    assert(Py_SIZE(t2) >= 0);
    assert(Py_SIZE(t2) <= 2*shift); /* 高桁との重複はありません */
    memcpy(ret->ob_digit, t2->ob_digit, Py_SIZE(t2) * sizeof(digit));

    /* 残りの桁をゼロにします。 */
    i = 2*shift - Py_SIZE(t2);          /* 初期化されていない桁の数 */
    if (i)
        memset(ret->ob_digit + Py_SIZE(t2), 0, i * sizeof(digit));

    /* 4 & 5. ah*bh (t1) と al*bl (t2) を引きます。最初に al*bl を行います。
     * それはキャッシュに新しいからです。
     */
    i = Py_SIZE(ret) - shift;  /* シフト後の桁数 */
    (void)v_isub(ret->ob_digit + shift, i, t2->ob_digit, Py_SIZE(t2));
    Py_DECREF(t2);

    (void)v_isub(ret->ob_digit + shift, i, t1->ob_digit, Py_SIZE(t1));
    Py_DECREF(t1);

    /* 6. t3 <- (ah+al)(bh+bl) を計算し、結果に追加します。 */
    if ((t1 = x_add(ah, al)) == NULL) goto fail;
    Py_DECREF(ah);
    Py_DECREF(al);
    ah = al = NULL;

    if (a == b) {
        t2 = t1;
        Py_INCREF(t2);
    }
    else if ((t2 = x_add(bh, bl)) == NULL) {
        Py_DECREF(t1);
        goto fail;
    }
    Py_DECREF(bh);
    Py_DECREF(bl);
    bh = bl = NULL;

    t3 = k_mul(t1, t2);
    Py_DECREF(t1);
    Py_DECREF(t2);
    if (t3 == NULL) goto fail;
    assert(Py_SIZE(t3) >= 0);

    /* t3 を追加します。ここでなぜここでスペースが足りなくならないのかは明らかではありません。
     * この関数の後の (*) コメントを参照してください。
     */
    (void)v_iadd(ret->ob_digit + shift, i, t3->ob_digit, Py_SIZE(t3));
    Py_DECREF(t3);

    return long_normalize(ret);

  fail:
    Py_XDECREF(ret);
    Py_XDECREF(ah);
    Py_XDECREF(al);
    Py_XDECREF(bh);
    Py_XDECREF(bl);
    return NULL;
}

ここでは Karatsuba アルゴリズム1 の実装については別途説明しません。興味がある方は文末のリファレンスを参照して具体的な詳細を理解してください。

通常の場合、通常の乗算の時間計算量は n^2 (n は桁数) ですが、K アルゴリズムの時間計算量は 3n^(log3) ≈ 3n^1.585 です。K アルゴリズムの性能は通常の乗算よりも優れているように見えますが、なぜ Python はすべての計算に K アルゴリズムを使用しないのでしょうか?

それは簡単です。K アルゴリズムの利点は、実際には n が十分大きいときにのみ、通常の乗算に対して優位性を持つからです。また、メモリアクセスなどの要因を考慮すると、n が十分大きくない場合、実際には K アルゴリズムの性能は直接の乗算よりも劣ります。

それでは、Python における乗算の実装を見てみましょう。

static PyObject *
long_mul(PyLongObject *a, PyLongObject *b)
{
    PyLongObject *z;

    CHECK_BINOP(a, b);

    /* 一桁の乗算のための高速パス */
    if (Py_ABS(Py_SIZE(a)) <= 1 && Py_ABS(Py_SIZE(b)) <= 1) {
        stwodigits v = (stwodigits)(MEDIUM_VALUE(a)) * MEDIUM_VALUE(b);
        return PyLong_FromLongLong((long long)v);
    }

    z = k_mul(a, b);
    /* 入力のうち一つだけが負の場合は符号を反転します。 */
    if (((Py_SIZE(a) ^ Py_SIZE(b)) < 0) && z) {
        _PyLong_Negate(&z);
        if (z == NULL)
            return NULL;
    }
    return (PyObject *)z;
}

ここでは、二つの数が共に 2^30-1 より小さい場合、Python は直接通常の乗算を使用し、それ以外の場合は K アルゴリズムを使用して計算します。

この時、ビット演算の実装を見てみましょう。右シフトの例を挙げます。

static PyObject *
long_rshift(PyObject *a, PyObject *b)
{
    Py_ssize_t wordshift;
    digit remshift;

    CHECK_BINOP(a, b);

    if (Py_SIZE(b) < 0) {
        PyErr_SetString(PyExc_ValueError, "negative shift count");
        return NULL;
    }
    if (Py_SIZE(a) == 0) {
        return PyLong_FromLong(0);
    }
    if (divmod_shift(b, &wordshift, &remshift) < 0)
        return NULL;
    return long_rshift1((PyLongObject *)a, wordshift, remshift);
}

static PyObject *
long_rshift1(PyLongObject *a, Py_ssize_t wordshift, digit remshift)
{
    PyLongObject *z = NULL;
    Py_ssize_t newsize, hishift, i, j;
    digit lomask, himask;

    if (Py_SIZE(a) < 0) {
        /* 負の数を右シフトするのは難しい */
        PyLongObject *a1, *a2;
        a1 = (PyLongObject *) long_invert(a);
        if (a1 == NULL)
            return NULL;
        a2 = (PyLongObject *) long_rshift1(a1, wordshift, remshift);
        Py_DECREF(a1);
        if (a2 == NULL)
            return NULL;
        z = (PyLongObject *) long_invert(a2);
        Py_DECREF(a2);
    }
    else {
        newsize = Py_SIZE(a) - wordshift;
        if (newsize <= 0)
            return PyLong_FromLong(0);
        hishift = PyLong_SHIFT - remshift;
        lomask = ((digit)1 << hishift) - 1;
        himask = PyLong_MASK ^ lomask;
        z = _PyLong_New(newsize);
        if (z == NULL)
            return NULL;
        for (i = 0, j = wordshift; i < newsize; i++, j++) {
            z->ob_digit[i] = (a->ob_digit[j] >> remshift) & lomask;
            if (i+1 < newsize)
                z->ob_digit[i] |= (a->ob_digit[j+1] << hishift) & himask;
        }
        z = maybe_small_long(long_normalize(z));
    }
    return (PyObject *)z;
}

ここでは、両側が小数の場合、ビットシフトアルゴリズムは通常の乗算よりも多くのメモリ割り当てなどの操作が存在します。これにより、文の初めに提起した「なぜ時々乗算がビット演算よりも速いのか」という問題に答えることができます。

まとめ#

この記事はここまでです。実際、今回の分析を通じて、非常に興味深くもあまり知られていない知識を得ることができました。実際、私たちが現在見ている結果は、Python が一般的で頻繁な操作に対して特定の設計を行ったものです。また、これは Python が多くの操作に対して独自の設計哲学を持っていることを私たちに思い出させます。日常の使用において、他の言語の経験は再利用できないかもしれません。

それでは、これで終わりです。水文を書いて生き延びるのがやっとです(逃げ)。

リファレンス#

読み込み中...
文章は、創作者によって署名され、ブロックチェーンに安全に保存されています。