Manjusaka

Why multiplication is sometimes faster than bit shifting in Python

I thought I was done writing filler posts, but it turns out filler posts are about all I can manage these days. So here is another one.

The other day, someone in a tech chat group asked: why is simple multiplication or division in Python sometimes faster than the equivalent bit operation?

In the spirit of checking the facts first, let's verify it:

In [33]: %timeit 1073741825*2
7.47 ns ± 0.0843 ns per loop (mean ± std. dev. of 7 runs, 100000000 loops each)

In [34]: %timeit 1073741825<<1
7.43 ns ± 0.0451 ns per loop (mean ± std. dev. of 7 runs, 100000000 loops each)

In [35]: %timeit 1073741823<<1
7.48 ns ± 0.0621 ns per loop (mean ± std. dev. of 7 runs, 100000000 loops each)

In [37]: %timeit 1073741823*2
7.47 ns ± 0.0564 ns per loop (mean ± std. dev. of 7 runs, 100000000 loops each)

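We notice a couple of interesting things:

  1. For values x <= 2^30, multiplication is faster than the shift
  2. For values x > 2^32, multiplication is significantly slower than the shift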
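One caveat when reproducing this: the four timings above are within noise of one another, and as far as I know CPython folds arithmetic on integer literals at compile time, so an expression like `1073741825*2` largely measures loading a precomputed constant. A minimal sketch of a fairer measurement binds the operand to a variable first. The helper name `time_op` and the loop counts are illustrative choices, not part of the original experiment, and the numbers will vary by machine and Python version.

```python
import timeit


def time_op(stmt: str, x: int, number: int = 100_000, repeat: int = 3) -> float:
    """Return the best per-call time in seconds for stmt, with x bound as a variable."""
    # Passing x through globals keeps the compiler from folding the expression.
    timer = timeit.Timer(stmt, globals={"x": x})
    return min(timer.repeat(repeat=repeat, number=number)) / number


results = {}
for x in (2**30 - 1, 2**30 + 1):          # one value on each side of the digit boundary
    for stmt in ("x * 2", "x << 1"):
        results[(x, stmt)] = time_op(stmt, x)
        print(f"x={x:<12} {stmt:<8} {results[(x, stmt)] * 1e9:8.2f} ns")
```

Taking the minimum over several repeats is the usual way to reduce scheduling noise with `timeit`; the absolute values still include the cost of the global lookup of `x`.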

That is an interesting result, so what is the root cause? It comes down to how Python is implemented under the hood.

A brief discussion

The implementation of PyLongObject

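In the Python 2.x era, integers came in two types: long and int. Python 3 merged the two, and only the former long implementation remains (exposed under the name int).

First, let's look at how long is implemented as a data structure: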

struct _longobject {
    PyObject_VAR_HEAD
    digit ob_digit[1];
};

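We don't need to care what PyObject_VAR_HEAD means here; ob_digit is the only part that matters.

Here, ob_digit uses what C99 calls a "flexible array" (declared with one element and over-allocated as needed) to store integers of arbitrary length. This is the documentation from the official source code: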

Long integer representation. The absolute value of a number is equal to SUM(for i=0 through abs(ob_size)-1) ob_digit[i] * 2**(SHIFT*i)
Negative numbers are represented with ob_size < 0; zero is represented by ob_size == 0.
In a normalized number, ob_digit[abs(ob_size)-1] (the most significant digit) is never zero. Also, in all cases, for all valid i, 0 <= ob_digit[i] <= MASK.
The allocation function takes care of allocating extra memory so that ob_digit[0] ... ob_digit[abs(ob_size)-1] are actually available.
CAUTION: Generic code manipulating subtypes of PyVarObject has to aware that ints abuse ob_size's sign bit.

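In short, Python stores an integer by converting it from decimal to base 2^SHIFT. That may be hard to picture, so here is an example. On my machine SHIFT is 30. Take the integer 1152921506754330628: in base 2^30 it is 4*(2^30)^0 + 2*(2^30)^1 + 1*(2^30)^2, so ob_digit is a three-element array with the values [4, 2, 1].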
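The decomposition above can be reproduced in a few lines of Python. This is only a sketch of the idea, not CPython's code: `to_digits` is a name made up here, and the default of 30 bits per digit is an assumption matching the machine described above (`sys.int_info.bits_per_digit` reports the value for a given build).

```python
from typing import List


def to_digits(n: int, shift: int = 30) -> List[int]:
    """Split a non-negative int into base-2**shift digits, least significant first."""
    if n < 0:
        raise ValueError("this sketch only handles non-negative integers")
    mask = (1 << shift) - 1
    digits = []
    while n:
        digits.append(n & mask)   # lowest `shift` bits form the next digit
        n >>= shift               # drop them and continue with the rest
    return digits                 # zero yields an empty list, like ob_size == 0


print(to_digits(1152921506754330628))  # [4, 2, 1]
print(to_digits(2**30 - 1))            # [1073741823], still a single digit
print(to_digits(2**30 + 1))            # [1, 1], now two digits
```

The last two lines show why 2^30 is the interesting boundary in the timings: 1073741823 fits in one digit, while 1073741825 needs two.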

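OK, with those basics in place, let's go back to multiplication in Python.

Multiplication in Python

Multiplication in Python comes in two parts. For large numbers, Python uses the Karatsuba algorithm [1], implemented as follows: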

static PyLongObject *
k_mul(PyLongObject *a, PyLongObject *b)
{
    Py_ssize_t asize = Py_ABS(Py_SIZE(a));
    Py_ssize_t bsize = Py_ABS(Py_SIZE(b));
    PyLongObject *ah = NULL;
    PyLongObject *al = NULL;
    PyLongObject *bh = NULL;
    PyLongObject *bl = NULL;
    PyLongObject *ret = NULL;
    PyLongObject *t1, *t2, *t3;
    Py_ssize_t shift;           /* the number of digits we split off */
    Py_ssize_t i;

    /* (ah*X+al)(bh*X+bl) = ah*bh*X*X + (ah*bl + al*bh)*X + al*bl
     * Let k = (ah+al)*(bh+bl) = ah*bl + al*bh  + ah*bh + al*bl
     * Then the original product is
     *     ah*bh*X*X + (k - ah*bh - al*bl)*X + al*bl
     * By picking X to be a power of 2, "*X" is just shifting, and it's
     * been reduced to 3 multiplies on numbers half the size.
     */

    /* We want to split based on the larger number; fiddle so that b
     * is largest.
     */
    if (asize > bsize) {
        t1 = a;
        a = b;
        b = t1;

        i = asize;
        asize = bsize;
        bsize = i;
    }

    /* Use gradeschool math when either number is too small. */
    i = a == b ? KARATSUBA_SQUARE_CUTOFF : KARATSUBA_CUTOFF;
    if (asize <= i) {
        if (asize == 0)
            return (PyLongObject *)PyLong_FromLong(0);
        else
            return x_mul(a, b);
    }

    /* If a is small compared to b, splitting on b gives a degenerate
     * case with ah==0, and Karatsuba may be (even much) less efficient
     * than "grade school" then.  However, we can still win, by viewing
     * b as a string of "big digits", each of width a->ob_size.  That
     * leads to a sequence of balanced calls to k_mul.
     */
    if (2 * asize <= bsize)
        return k_lopsided_mul(a, b);

    /* Split a & b into hi & lo pieces. */
    shift = bsize >> 1;
    if (kmul_split(a, shift, &ah, &al) < 0) goto fail;
    assert(Py_SIZE(ah) > 0);            /* the split isn't degenerate */

    if (a == b) {
        bh = ah;
        bl = al;
        Py_INCREF(bh);
        Py_INCREF(bl);
    }
    else if (kmul_split(b, shift, &bh, &bl) < 0) goto fail;

    /* The plan:
     * 1. Allocate result space (asize + bsize digits:  that's always
     *    enough).
     * 2. Compute ah*bh, and copy into result at 2*shift.
     * 3. Compute al*bl, and copy into result at 0.  Note that this
     *    can't overlap with #2.
     * 4. Subtract al*bl from the result, starting at shift.  This may
     *    underflow (borrow out of the high digit), but we don't care:
     *    we're effectively doing unsigned arithmetic mod
     *    BASE**(sizea + sizeb), and so long as the *final* result fits,
     *    borrows and carries out of the high digit can be ignored.
     * 5. Subtract ah*bh from the result, starting at shift.
     * 6. Compute (ah+al)*(bh+bl), and add it into the result starting
     *    at shift.
     */

    /* 1. Allocate result space. */
    ret = _PyLong_New(asize + bsize);
    if (ret == NULL) goto fail;
#ifdef Py_DEBUG
    /* Fill with trash, to catch reference to uninitialized digits. */
    memset(ret->ob_digit, 0xDF, Py_SIZE(ret) * sizeof(digit));
#endif

    /* 2. t1 <- ah*bh, and copy into high digits of result. */
    if ((t1 = k_mul(ah, bh)) == NULL) goto fail;
    assert(Py_SIZE(t1) >= 0);
    assert(2*shift + Py_SIZE(t1) <= Py_SIZE(ret));
    memcpy(ret->ob_digit + 2*shift, t1->ob_digit,
           Py_SIZE(t1) * sizeof(digit));

    /* Zero-out the digits higher than the ah*bh copy. */
    i = Py_SIZE(ret) - 2*shift - Py_SIZE(t1);
    if (i)
        memset(ret->ob_digit + 2*shift + Py_SIZE(t1), 0,
               i * sizeof(digit));

    /* 3. t2 <- al*bl, and copy into the low digits. */
    if ((t2 = k_mul(al, bl)) == NULL) {
        Py_DECREF(t1);
        goto fail;
    }
    assert(Py_SIZE(t2) >= 0);
    assert(Py_SIZE(t2) <= 2*shift); /* no overlap with high digits */
    memcpy(ret->ob_digit, t2->ob_digit, Py_SIZE(t2) * sizeof(digit));

    /* Zero out remaining digits. */
    i = 2*shift - Py_SIZE(t2);          /* number of uninitialized digits */
    if (i)
        memset(ret->ob_digit + Py_SIZE(t2), 0, i * sizeof(digit));

    /* 4 & 5. Subtract ah*bh (t1) and al*bl (t2).  We do al*bl first
     * because it's fresher in cache.
     */
    i = Py_SIZE(ret) - shift;  /* # digits after shift */
    (void)v_isub(ret->ob_digit + shift, i, t2->ob_digit, Py_SIZE(t2));
    Py_DECREF(t2);

    (void)v_isub(ret->ob_digit + shift, i, t1->ob_digit, Py_SIZE(t1));
    Py_DECREF(t1);

    /* 6. t3 <- (ah+al)(bh+bl), and add into result. */
    if ((t1 = x_add(ah, al)) == NULL) goto fail;
    Py_DECREF(ah);
    Py_DECREF(al);
    ah = al = NULL;

    if (a == b) {
        t2 = t1;
        Py_INCREF(t2);
    }
    else if ((t2 = x_add(bh, bl)) == NULL) {
        Py_DECREF(t1);
        goto fail;
    }
    Py_DECREF(bh);
    Py_DECREF(bl);
    bh = bl = NULL;

    t3 = k_mul(t1, t2);
    Py_DECREF(t1);
    Py_DECREF(t2);
    if (t3 == NULL) goto fail;
    assert(Py_SIZE(t3) >= 0);

    /* Add t3.  It's not obvious why we can't run out of room here.
     * See the (*) comment after this function.
     */
    (void)v_iadd(ret->ob_digit + shift, i, t3->ob_digit, Py_SIZE(t3));
    Py_DECREF(t3);

    return long_normalize(ret);

  fail:
    Py_XDECREF(ret);
    Py_XDECREF(ah);
    Py_XDECREF(al);
    Py_XDECREF(bh);
    Py_XDECREF(bl);
    return NULL;
}

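I won't explain the implementation of the Karatsuba algorithm [1] separately here; if you are interested, see the reference at the end of the post for the details.

Ordinarily, schoolbook multiplication has time complexity n^2 (where n is the number of digits), while Karatsuba's is 3n^(log2 3) ≈ 3n^1.585. That makes Karatsuba look strictly better, so why doesn't Python use it everywhere?

The reason is simple: Karatsuba only beats schoolbook multiplication once n is large enough. Factoring in memory access and other overheads, when n is not large enough Karatsuba actually performs worse than multiplying directly.

So let's look at how multiplication is actually implemented in Python: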
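For readers who want to see the three-multiply identity from the comment in k_mul in action, here is a small Python sketch working on plain ints. It is an illustration under assumptions, not a port of CPython's code: the function name `karatsuba` and the `cutoff_bits` threshold are made up here (CPython's cutoff is expressed in digits via KARATSUBA_CUTOFF), and only non-negative operands are handled.

```python
def karatsuba(x: int, y: int, cutoff_bits: int = 64) -> int:
    """Multiply two non-negative ints using Karatsuba's split into three products."""
    if x < 0 or y < 0:
        raise ValueError("this sketch only handles non-negative integers")
    # Small operands: splitting costs more than it saves, so multiply directly.
    if x.bit_length() <= cutoff_bits or y.bit_length() <= cutoff_bits:
        return x * y

    # Split both numbers at the same bit position: x = xh * 2**shift + xl.
    shift = max(x.bit_length(), y.bit_length()) // 2
    mask = (1 << shift) - 1
    xh, xl = x >> shift, x & mask
    yh, yl = y >> shift, y & mask

    hi = karatsuba(xh, yh, cutoff_bits)            # ah*bh
    lo = karatsuba(xl, yl, cutoff_bits)            # al*bl
    k = karatsuba(xh + xl, yh + yl, cutoff_bits)   # (ah+al)*(bh+bl)
    mid = k - hi - lo                              # ah*bl + al*bh

    # Multiplying by a power of two is just a shift.
    return (hi << (2 * shift)) + (mid << shift) + lo


a, b = 3**500, 7**400
print(karatsuba(a, b) == a * b)  # True
```

The early return plays the same role as the "gradeschool" branch in k_mul: below some size the recursion overhead dominates, which is the trade-off described above.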

static PyObject *
long_mul(PyLongObject *a, PyLongObject *b)
{
    PyLongObject *z;

    CHECK_BINOP(a, b);

    /* fast path for single-digit multiplication */
    if (Py_ABS(Py_SIZE(a)) <= 1 && Py_ABS(Py_SIZE(b)) <= 1) {
        stwodigits v = (stwodigits)(MEDIUM_VALUE(a)) * MEDIUM_VALUE(b);
        return PyLong_FromLongLong((long long)v);
    }

    z = k_mul(a, b);
    /* Negate if exactly one of the inputs is negative. */
    if (((Py_SIZE(a) ^ Py_SIZE(b)) < 0) && z) {
        _PyLong_Negate(&z);
        if (z == NULL)
            return NULL;
    }
    return (PyObject *)z;
}

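Here we see that when both operands fit in a single digit (absolute value at most 2^30-1), Python multiplies them directly with a native C multiplication and returns. Otherwise it calls k_mul, which itself falls back to schoolbook multiplication (x_mul) for operands below the Karatsuba cutoff.

Now let's look at how the bit operations are implemented, taking right shift as the example: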
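The condition guarding the fast path in long_mul can be modeled in Python to check which operands qualify. This is a hedged sketch: `digit_count` and `uses_single_digit_fast_path` are hypothetical helper names, and 30 bits per digit is assumed as in the example earlier.

```python
def digit_count(n: int, shift: int = 30) -> int:
    """Number of base-2**shift digits in abs(n); 0 for n == 0, like abs(ob_size)."""
    return (abs(n).bit_length() + shift - 1) // shift


def uses_single_digit_fast_path(a: int, b: int, shift: int = 30) -> bool:
    """Mirror the check Py_ABS(Py_SIZE(a)) <= 1 && Py_ABS(Py_SIZE(b)) <= 1."""
    return digit_count(a, shift) <= 1 and digit_count(b, shift) <= 1


print(uses_single_digit_fast_path(1073741823, 2))  # True: 2**30 - 1 is one digit
print(uses_single_digit_fast_path(1073741825, 2))  # False: 2**30 + 1 needs two digits
```

These are the two operands used in the timings at the top of the post, one on each side of the single-digit boundary.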


static PyObject *
long_rshift(PyObject *a, PyObject *b)
{
    Py_ssize_t wordshift;
    digit remshift;

    CHECK_BINOP(a, b);

    if (Py_SIZE(b) < 0) {
        PyErr_SetString(PyExc_ValueError, "negative shift count");
        return NULL;
    }
    if (Py_SIZE(a) == 0) {
        return PyLong_FromLong(0);
    }
    if (divmod_shift(b, &wordshift, &remshift) < 0)
        return NULL;
    return long_rshift1((PyLongObject *)a, wordshift, remshift);
}

static PyObject *
long_rshift1(PyLongObject *a, Py_ssize_t wordshift, digit remshift)
{
    PyLongObject *z = NULL;
    Py_ssize_t newsize, hishift, i, j;
    digit lomask, himask;

    if (Py_SIZE(a) < 0) {
        /* Right shifting negative numbers is harder */
        PyLongObject *a1, *a2;
        a1 = (PyLongObject *) long_invert(a);
        if (a1 == NULL)
            return NULL;
        a2 = (PyLongObject *) long_rshift1(a1, wordshift, remshift);
        Py_DECREF(a1);
        if (a2 == NULL)
            return NULL;
        z = (PyLongObject *) long_invert(a2);
        Py_DECREF(a2);
    }
    else {
        newsize = Py_SIZE(a) - wordshift;
        if (newsize <= 0)
            return PyLong_FromLong(0);
        hishift = PyLong_SHIFT - remshift;
        lomask = ((digit)1 << hishift) - 1;
        himask = PyLong_MASK ^ lomask;
        z = _PyLong_New(newsize);
        if (z == NULL)
            return NULL;
        for (i = 0, j = wordshift; i < newsize; i++, j++) {
            z->ob_digit[i] = (a->ob_digit[j] >> remshift) & lomask;
            if (i+1 < newsize)
                z->ob_digit[i] |= (a->ob_digit[j+1] << hishift) & himask;
        }
        z = maybe_small_long(long_normalize(z));
    }
    return (PyObject *)z;
}

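Here we can see that when both operands are small integers, the shift path does more work than the multiplication fast path: it splits the shift count with divmod_shift, allocates a result object with _PyLong_New, loops over the digits, and normalizes the result. That answers the question from the start of the post: "why is multiplication sometimes faster than a bit operation?"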
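To make the digit loop in long_rshift1 easier to follow, here is a Python model of its non-negative branch operating on a list of digits. It is an illustrative sketch only: `to_digits`, `from_digits` and `rshift_digits` are names made up here, SHIFT is assumed to be 30, and negative numbers are not handled.

```python
from typing import List

SHIFT = 30                    # assumed bits per digit
MASK = (1 << SHIFT) - 1


def to_digits(n: int) -> List[int]:
    """Non-negative int -> base-2**SHIFT digits, least significant first."""
    digits = []
    while n:
        digits.append(n & MASK)
        n >>= SHIFT
    return digits


def from_digits(digits: List[int]) -> int:
    """Inverse of to_digits."""
    return sum(d << (SHIFT * i) for i, d in enumerate(digits))


def rshift_digits(digits: List[int], count: int) -> List[int]:
    """Shift right by count bits, following the loop in long_rshift1."""
    wordshift, remshift = divmod(count, SHIFT)   # whole digits and leftover bits
    newsize = len(digits) - wordshift
    if newsize <= 0:
        return []
    hishift = SHIFT - remshift
    lomask = (1 << hishift) - 1
    himask = MASK ^ lomask
    z = [0] * newsize                            # the allocation step
    for i in range(newsize):
        j = i + wordshift
        z[i] = (digits[j] >> remshift) & lomask
        if i + 1 < newsize:
            # Pull the low bits of the next digit into the top of this one.
            z[i] |= (digits[j + 1] << hishift) & himask
    while z and z[-1] == 0:                      # normalize: drop leading zero digits
        z.pop()
    return z


n = 1152921506754330628
print(from_digits(rshift_digits(to_digits(n), 31)) == n >> 31)  # True
```

Even in this simplified form the shift needs a divmod, a fresh result buffer and a loop, whereas the single-digit multiplication path is one machine multiply plus the conversion back to an int object.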

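Summary

That is about it for this post. This analysis turns up some knowledge that is interesting, if rather obscure. The result we see is a deliberate design choice Python makes for common, high-frequency operations. It is also a reminder that Python has its own built-in design philosophy for many operations, and experience from other languages may not carry over in day-to-day use.

That's about all. Filler posts are how I scrape by these days (runs away)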

Reference
