Manjusaka

Why multiplication is sometimes faster than bitwise operations in Python

I thought I had stopped writing filler articles, but it turns out filler articles are all I can barely write to make a living. So here is another one.

One day, someone in a tech chat group asked: why is simple multiplication/division in Python sometimes faster than the equivalent bitwise operation?

First, adhering to the spirit of seeking truth from facts, let's verify this.

In [33]: %timeit 1073741825*2
7.47 ns ± 0.0843 ns per loop (mean ± std. dev. of 7 runs, 100000000 loops each)

In [34]: %timeit 1073741825<<1
7.43 ns ± 0.0451 ns per loop (mean ± std. dev. of 7 runs, 100000000 loops each)

In [35]: %timeit 1073741823<<1
7.48 ns ± 0.0621 ns per loop (mean ± std. dev. of 7 runs, 100000000 loops each)

In [37]: %timeit 1073741823*2
7.47 ns ± 0.0564 ns per loop (mean ± std. dev. of 7 runs, 100000000 loops each)

We found two interesting phenomena:

  1. When the operand x < 2^30 (it fits in a single 30-bit digit), multiplication is faster than the equivalent bitwise shift.
  2. When x >= 2^30, multiplication is significantly slower than the bitwise shift.
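One caveat is worth checking before trusting numbers like these. CPython's compiler folds constant expressions such as 1073741825*2 into a single constant at compile time. So %timeit on a literal expression may mostly be timing a constant load, which could be why the four results above are nearly identical.

The sketch below uses plain timeit rather than IPython, and the names folded_constants and bench are illustrative. It first shows the folding. It then binds the operand to a local variable in the setup string, so the multiplication and the shift actually execute on every loop. It prints timings instead of claiming a result, because the numbers depend on the machine and the Python version.

```python
import timeit

def folded_constants(expr):
    """Constants the compiler stored for expr; a precomputed product shows up here."""
    return compile(expr, "<bench>", "eval").co_consts

def bench(x, number=200_000, repeat=5):
    """Best-of-repeat nanoseconds per operation for x * 2 and x << 1."""
    setup = f"x = {x}"  # x is assigned at run time, so the compiler cannot fold x * 2
    result = {}
    for name, stmt in (("mul", "x * 2"), ("shl", "x << 1")):
        best = min(timeit.repeat(stmt, setup=setup, number=number, repeat=repeat))
        result[name] = best / number * 1e9  # seconds per batch -> ns per operation
    return result

if __name__ == "__main__":
    print(folded_constants("1073741825*2"))  # includes 2147483650: computed at compile time
    for x in (2**30 - 1, 2**30 + 1):
        t = bench(x)
        print(f"x={x}: x*2 {t['mul']:.2f} ns, x<<1 {t['shl']:.2f} ns per op")
```

Taking the minimum over several repeats is the usual way to reduce noise from other processes. Only the relative numbers for the two operand sizes are meaningful.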

This is interesting, so what is the root cause? It comes down to CPython's underlying implementation of integers.

A Brief Discussion

Implementation of PyLongObject

In the Python 2.x era, integers came in two types: int and long. Python 3 merged them into a single arbitrary-precision type. At the Python level it is called int, but in C it is still implemented as the long type (PyLongObject).

First, let's take a look at the underlying implementation of the long data structure.

struct _longobject {
    PyObject_VAR_HEAD
    digit ob_digit[1];
};

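We can ignore most of what PyObject_VAR_HEAD contains. The one part that matters is its ob_size field, which the comment below uses to store the number of digits and the sign. Apart from that, we only need to focus on ob_digit.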

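ob_digit is a variable-length trailing array: it is declared with size 1 and over-allocated when the object is created. This is the pre-C99 idiom that C99 formalized as the "flexible array member", and it is how integers of arbitrary length are stored. The comment in the CPython source describes the representation: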

Long integer representation. The absolute value of a number is equal to SUM(for i=0 through abs(ob_size)-1) ob_digit[i] * 2**(SHIFT*i)
Negative numbers are represented with ob_size < 0; zero is represented by ob_size == 0.
In a normalized number, ob_digit[abs(ob_size)-1] (the most significant digit) is never zero. Also, in all cases, for all valid i, 0 <= ob_digit[i] <= MASK.
The allocation function takes care of allocating extra memory so that ob_digit[0] ... ob_digit[abs(ob_size)-1] are actually available.
CAUTION: Generic code manipulating subtypes of PyVarObject has to be aware that ints abuse ob_size's sign bit.
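The point about the allocation function reserving room for every digit can be observed from Python. This is a minimal sketch that assumes CPython, where sys.getsizeof reports the int object's allocated size and sys.int_info describes the digit layout. The helper name size_per_digit is illustrative. An integer that needs one more digit should report exactly one digit's worth of extra bytes.

```python
import sys

def size_per_digit():
    """Bytes that one extra digit adds to an int object's reported size (CPython)."""
    s = sys.int_info.bits_per_digit
    two_digits = 1 << s          # needs s + 1 bits -> 2 digits
    three_digits = 1 << (2 * s)  # needs 2s + 1 bits -> 3 digits
    return sys.getsizeof(three_digits) - sys.getsizeof(two_digits)

if __name__ == "__main__":
    print(sys.int_info)  # bits_per_digit and sizeof_digit for this build
    print(size_per_digit() == sys.int_info.sizeof_digit)
```

Zero is avoided on purpose: newer CPython versions always reserve at least one digit, even though zero has a digit count of 0.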

In short, Python stores the absolute value of an integer in base 2^SHIFT, one "digit" per array element, least significant first. This might be a bit hard to picture, so here is an example. On my computer SHIFT is 30 (sys.int_info.bits_per_digit reports it). Take the integer 1152921506754330628. In base 2^30 it is 4*(2^30)^0 + 2*(2^30)^1 + 1*(2^30)^2, so ob_digit is an array of three elements with values [4, 2, 1], and ob_size is 3.
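The conversion in both directions can be sketched in a few lines of Python. This is a minimal illustration: to_digits and from_digits are hypothetical helpers that compute the layout arithmetically rather than reading it from a real object. The sign is kept in ob_size, and the digits hold the absolute value, as the quoted comment describes.

```python
import sys

def to_digits(n, shift=sys.int_info.bits_per_digit):
    """Split n into (ob_size, ob_digit) following the CPython comment quoted above."""
    mask = (1 << shift) - 1
    magnitude, digits = abs(n), []
    while magnitude:
        digits.append(magnitude & mask)  # least significant digit first
        magnitude >>= shift
    # The sign lives in ob_size, not in the digits; zero has ob_size == 0.
    return (-len(digits) if n < 0 else len(digits)), digits

def from_digits(ob_size, digits, shift=sys.int_info.bits_per_digit):
    """Inverse: SUM(ob_digit[i] * 2**(SHIFT*i)), negated when ob_size < 0."""
    magnitude = sum(d << (shift * i) for i, d in enumerate(digits[:abs(ob_size)]))
    return -magnitude if ob_size < 0 else magnitude

print(to_digits(1152921506754330628, shift=30))  # (3, [4, 2, 1])
```

Because the loop stops once the remaining magnitude is zero, the most significant digit is never zero. That matches the "normalized number" rule in the comment.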

OK, with this background, let's get back to multiplication in Python.

Multiplication in Python

Multiplication in CPython has two code paths. For large numbers, it uses the Karatsuba algorithm[1], implemented as follows:

static PyLongObject *
k_mul(PyLongObject *a, PyLongObject *b)
{
    Py_ssize_t asize = Py_ABS(Py_SIZE(a));
    Py_ssize_t bsize = Py_ABS(Py_SIZE(b));
    PyLongObject *ah = NULL;
    PyLongObject *al = NULL;
    PyLongObject *bh = NULL;
    PyLongObject *bl = NULL;
    PyLongObject *ret = NULL;
    PyLongObject *t1, *t2, *t3;
    Py_ssize_t shift;           /* the number of digits we split off */
    Py_ssize_t i;

    /* (ah*X+al)(bh*X+bl) = ah*bh*X*X + (ah*bl + al*bh)*X + al*bl
     * Let k = (ah+al)*(bh+bl) = ah*bl + al*bh  + ah*bh + al*bl
     * Then the original product is
     *     ah*bh*X*X + (k - ah*bh - al*bl)*X + al*bl
     * By picking X to be a power of 2, "*X" is just shifting, and it's
     * been reduced to 3 multiplies on numbers half the size.
     */

    /* We want to split based on the larger number; fiddle so that b
     * is largest.
     */
    if (asize > bsize) {
        t1 = a;
        a = b;
        b = t1;

        i = asize;
        asize = bsize;
        bsize = i;
    }

    /* Use gradeschool math when either number is too small. */
    i = a == b ? KARATSUBA_SQUARE_CUTOFF : KARATSUBA_CUTOFF;
    if (asize <= i) {
        if (asize == 0)
            return (PyLongObject *)PyLong_FromLong(0);
        else
            return x_mul(a, b);
    }

    /* If a is small compared to b, splitting on b gives a degenerate
     * case with ah==0, and Karatsuba may be (even much) less efficient
     * than "grade school" then.  However, we can still win, by viewing
     * b as a string of "big digits", each of width a->ob_size.  That
     * leads to a sequence of balanced calls to k_mul.
     */
    if (2 * asize <= bsize)
        return k_lopsided_mul(a, b);

    /* Split a & b into hi & lo pieces. */
    shift = bsize >> 1;
    if (kmul_split(a, shift, &ah, &al) < 0) goto fail;
    assert(Py_SIZE(ah) > 0);            /* the split isn't degenerate */

    if (a == b) {
        bh = ah;
        bl = al;
        Py_INCREF(bh);
        Py_INCREF(bl);
    }
    else if (kmul_split(b, shift, &bh, &bl) < 0) goto fail;

    /* The plan:
     * 1. Allocate result space (asize + bsize digits:  that's always
     *    enough).
     * 2. Compute ah*bh, and copy into result at 2*shift.
     * 3. Compute al*bl, and copy into result at 0.  Note that this
     *    can't overlap with #2.
     * 4. Subtract al*bl from the result, starting at shift.  This may
     *    underflow (borrow out of the high digit), but we don't care:
     *    we're effectively doing unsigned arithmetic mod
     *    BASE**(sizea + sizeb), and so long as the *final* result fits,
     *    borrows and carries out of the high digit can be ignored.
     * 5. Subtract ah*bh from the result, starting at shift.
     * 6. Compute (ah+al)*(bh+bl), and add it into the result starting
     *    at shift.
     */

    /* 1. Allocate result space. */
    ret = _PyLong_New(asize + bsize);
    if (ret == NULL) goto fail;
#ifdef Py_DEBUG
    /* Fill with trash, to catch reference to uninitialized digits. */
    memset(ret->ob_digit, 0xDF, Py_SIZE(ret) * sizeof(digit));
#endif

    /* 2. t1 <- ah*bh, and copy into high digits of result. */
    if ((t1 = k_mul(ah, bh)) == NULL) goto fail;
    assert(Py_SIZE(t1) >= 0);
    assert(2*shift + Py_SIZE(t1) <= Py_SIZE(ret));
    memcpy(ret->ob_digit + 2*shift, t1->ob_digit,
           Py_SIZE(t1) * sizeof(digit));

    /* Zero-out the digits higher than the ah*bh copy. */
    i = Py_SIZE(ret) - 2*shift - Py_SIZE(t1);
    if (i)
        memset(ret->ob_digit + 2*shift + Py_SIZE(t1), 0,
               i * sizeof(digit));

    /* 3. t2 <- al*bl, and copy into the low digits. */
    if ((t2 = k_mul(al, bl)) == NULL) {
        Py_DECREF(t1);
        goto fail;
    }
    assert(Py_SIZE(t2) >= 0);
    assert(Py_SIZE(t2) <= 2*shift); /* no overlap with high digits */
    memcpy(ret->ob_digit, t2->ob_digit, Py_SIZE(t2) * sizeof(digit));

    /* Zero out remaining digits. */
    i = 2*shift - Py_SIZE(t2);          /* number of uninitialized digits */
    if (i)
        memset(ret->ob_digit + Py_SIZE(t2), 0, i * sizeof(digit));

    /* 4 & 5. Subtract ah*bh (t1) and al*bl (t2).  We do al*bl first
     * because it's fresher in cache.
     */
    i = Py_SIZE(ret) - shift;  /* # digits after shift */
    (void)v_isub(ret->ob_digit + shift, i, t2->ob_digit, Py_SIZE(t2));
    Py_DECREF(t2);

    (void)v_isub(ret->ob_digit + shift, i, t1->ob_digit, Py_SIZE(t1));
    Py_DECREF(t1);

    /* 6. t3 <- (ah+al)(bh+bl), and add into result. */
    if ((t1 = x_add(ah, al)) == NULL) goto fail;
    Py_DECREF(ah);
    Py_DECREF(al);
    ah = al = NULL;

    if (a == b) {
        t2 = t1;
        Py_INCREF(t2);
    }
    else if ((t2 = x_add(bh, bl)) == NULL) {
        Py_DECREF(t1);
        goto fail;
    }
    Py_DECREF(bh);
    Py_DECREF(bl);
    bh = bl = NULL;

    t3 = k_mul(t1, t2);
    Py_DECREF(t1);
    Py_DECREF(t2);
    if (t3 == NULL) goto fail;
    assert(Py_SIZE(t3) >= 0);

    /* Add t3.  It's not obvious why we can't run out of room here.
     * See the (*) comment after this function.
     */
    (void)v_iadd(ret->ob_digit + shift, i, t3->ob_digit, Py_SIZE(t3));
    Py_DECREF(t3);

    return long_normalize(ret);

  fail:
    Py_XDECREF(ret);
    Py_XDECREF(ah);
    Py_XDECREF(al);
    Py_XDECREF(bh);
    Py_XDECREF(bl);
    return NULL;
}

I won't walk through the Karatsuba algorithm[1] in detail here. Interested readers can find the specifics in the references at the end.
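For reference, Karatsuba's complexity follows from the splitting identity in the k_mul comment. One product of two n-digit numbers becomes three products of roughly n/2-digit numbers, plus a linear amount of shifting, adding and subtracting. Grade-school multiplication instead needs all n^2 digit products. As a recurrence, with constants omitted (master theorem with a = 3, b = 2, f(n) = Θ(n)):

```latex
T(n) = 3\,T\!\left(\tfrac{n}{2}\right) + \Theta(n)
\quad\Longrightarrow\quad
T(n) = \Theta\!\left(n^{\log_2 3}\right) \approx \Theta\!\left(n^{1.585}\right)
```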

Grade-school multiplication takes O(n^2) time (where n is the number of digits), while Karatsuba takes O(n^(log2 3)) ≈ O(n^1.585). Karatsuba looks asymptotically better, so why doesn't Python use it for every multiplication?

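The reason is simple: Karatsuba's advantage only shows up when n is large enough. Each level of recursion pays extra overhead for splitting the operands, for the additional additions and subtractions, and for allocating temporary objects, with the memory traffic that implies. For small n that overhead makes it slower than grade-school multiplication. This is why k_mul falls back to x_mul when the smaller operand has at most KARATSUBA_CUTOFF digits.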
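To make the trade-off concrete, here is a minimal Karatsuba sketch on Python ints. It is an illustration under stated assumptions, not CPython's k_mul:

- it splits on bits rather than on 30-bit digits;
- it uses a hypothetical CUTOFF_BITS threshold instead of KARATSUBA_CUTOFF;
- it only accepts non-negative operands;
- below the cutoff it falls back to the built-in * operator, which stands in for x_mul.

The three recursive calls correspond to t1, t2 and t3 in the C code above.

```python
CUTOFF_BITS = 2048  # hypothetical threshold; CPython's KARATSUBA_CUTOFF counts digits instead

def karatsuba(a, b):
    """Karatsuba product of two non-negative ints (illustrative, not CPython's k_mul)."""
    if a < 0 or b < 0:
        raise ValueError("this sketch only handles non-negative operands")
    if min(a.bit_length(), b.bit_length()) <= CUTOFF_BITS:
        return a * b  # small operand: plain multiplication stands in for x_mul
    shift = max(a.bit_length(), b.bit_length()) // 2  # X = 2**shift
    mask = (1 << shift) - 1
    ah, al = a >> shift, a & mask
    bh, bl = b >> shift, b & mask
    t1 = karatsuba(ah, bh)            # ah*bh
    t2 = karatsuba(al, bl)            # al*bl
    t3 = karatsuba(ah + al, bh + bl)  # (ah+al)*(bh+bl) = ah*bl + al*bh + t1 + t2
    # ah*bh*X*X + (t3 - t1 - t2)*X + al*bl, where multiplying by X is a left shift
    return (t1 << (2 * shift)) + ((t3 - t1 - t2) << shift) + t2

print(karatsuba(3**5000, 7**4000) == 3**5000 * 7**4000)  # True
```

Unlike k_mul, this sketch does not special-case lopsided operands (k_lopsided_mul) or squaring. A very unbalanced pair simply produces a zero high half and recurses until one side drops below the cutoff.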

Now let's look at long_mul, the entry point for integer multiplication.

static PyObject *
long_mul(PyLongObject *a, PyLongObject *b)
{
    PyLongObject *z;

    CHECK_BINOP(a, b);

    /* fast path for single-digit multiplication */
    if (Py_ABS(Py_SIZE(a)) <= 1 && Py_ABS(Py_SIZE(b)) <= 1) {
        stwodigits v = (stwodigits)(MEDIUM_VALUE(a)) * MEDIUM_VALUE(b);
        return PyLong_FromLongLong((long long)v);
    }

    z = k_mul(a, b);
    /* Negate if exactly one of the inputs is negative. */
    if (((Py_SIZE(a) ^ Py_SIZE(b)) < 0) && z) {
        _PyLong_Negate(&z);
        if (z == NULL)
            return NULL;
    }
    return (PyObject *)z;
}

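Here we see that when both operands fit in a single digit (abs(ob_size) <= 1, i.e. absolute value at most 2^30 - 1 with 30-bit digits), Python multiplies them directly as C integers of type stwodigits and returns the result via PyLong_FromLongLong. Otherwise it calls k_mul. That function still falls back to grade-school x_mul unless the smaller operand has more than KARATSUBA_CUTOFF digits (KARATSUBA_SQUARE_CUTOFF when squaring).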
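The branch can be mirrored in Python to predict which path a given product takes. This is a minimal sketch that assumes CPython and the long_mul version quoted above; newer releases restructure this code. ndigits and mul_path are hypothetical helper names, and the digit count is derived from sys.int_info.bits_per_digit rather than read from the object.

```python
import sys

SHIFT = sys.int_info.bits_per_digit  # 30 on typical 64-bit builds

def ndigits(n):
    """abs(ob_size): how many base-2**SHIFT digits n occupies (0 for zero)."""
    return -(-abs(n).bit_length() // SHIFT)  # ceiling division

def mul_path(a, b):
    """Which branch of the quoted long_mul would handle a * b."""
    if ndigits(a) <= 1 and ndigits(b) <= 1:
        return "fast path"  # one C multiplication of two single digits
    return "k_mul"          # which may still end up in x_mul for small sizes

for x in (2**SHIFT - 1, 2**SHIFT + 1):
    print(x, mul_path(x, 2))
```

With 30-bit digits, this puts 1073741823 * 2 on the fast path and 1073741825 * 2 through k_mul. Those are exactly the two operands used in the timing session at the top of the article.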

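Now let's look at the implementation of a bitwise operation, taking right shift as an example. The benchmark above used left shift, but in the CPython version quoted here long_lshift follows the same pattern: it splits the shift count with divmod_shift and builds a new object digit by digit.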


static PyObject *
long_rshift(PyObject *a, PyObject *b)
{
    Py_ssize_t wordshift;
    digit remshift;

    CHECK_BINOP(a, b);

    if (Py_SIZE(b) < 0) {
        PyErr_SetString(PyExc_ValueError, "negative shift count");
        return NULL;
    }
    if (Py_SIZE(a) == 0) {
        return PyLong_FromLong(0);
    }
    if (divmod_shift(b, &wordshift, &remshift) < 0)
        return NULL;
    return long_rshift1((PyLongObject *)a, wordshift, remshift);
}

static PyObject *
long_rshift1(PyLongObject *a, Py_ssize_t wordshift, digit remshift)
{
    PyLongObject *z = NULL;
    Py_ssize_t newsize, hishift, i, j;
    digit lomask, himask;

    if (Py_SIZE(a) < 0) {
        /* Right shifting negative numbers is harder */
        PyLongObject *a1, *a2;
        a1 = (PyLongObject *) long_invert(a);
        if (a1 == NULL)
            return NULL;
        a2 = (PyLongObject *) long_rshift1(a1, wordshift, remshift);
        Py_DECREF(a1);
        if (a2 == NULL)
            return NULL;
        z = (PyLongObject *) long_invert(a2);
        Py_DECREF(a2);
    }
    else {
        newsize = Py_SIZE(a) - wordshift;
        if (newsize <= 0)
            return PyLong_FromLong(0);
        hishift = PyLong_SHIFT - remshift;
        lomask = ((digit)1 << hishift) - 1;
        himask = PyLong_MASK ^ lomask;
        z = _PyLong_New(newsize);
        if (z == NULL)
            return NULL;
        for (i = 0, j = wordshift; i < newsize; i++, j++) {
            z->ob_digit[i] = (a->ob_digit[j] >> remshift) & lomask;
            if (i+1 < newsize)
                z->ob_digit[i] |= (a->ob_digit[j+1] << hishift) & himask;
        }
        z = maybe_small_long(long_normalize(z));
    }
    return (PyObject *)z;
}

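Here we can see that even when both operands are small, the shift takes more steps than the multiplication fast path. long_rshift validates the shift count and calls divmod_shift. long_rshift1 then allocates a new object with _PyLong_New, fills it in a loop, and normalizes it with long_normalize and maybe_small_long. By contrast, long_mul's fast path is a single C multiplication followed by PyLong_FromLongLong. Once an operand reaches 2^30, multiplication loses that fast path and goes through k_mul and x_mul, so its advantage disappears. This answers the question we raised at the beginning: "Why is multiplication sometimes faster than bitwise operations?"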
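For comparison, the loop in long_rshift1 can be replayed in Python on a list of digits. This is a minimal sketch that assumes 30-bit digits. split, join, rshift_digits and rshift_int are illustrative names, and Python lists stand in for PyLongObject allocation. Even for a one-digit operand, it performs the divmod_shift step, builds a fresh digit array and normalizes it. That is the bookkeeping the single-digit multiplication fast path skips.

```python
SHIFT = 30                 # assumption: a build with 30-bit digits
MASK = (1 << SHIFT) - 1

def split(n):
    """Non-negative int -> little-endian list of SHIFT-bit digits (hypothetical helper)."""
    digits = []
    while n:
        digits.append(n & MASK)
        n >>= SHIFT
    return digits

def join(digits):
    """Inverse of split: SUM(digits[i] * 2**(SHIFT*i))."""
    return sum(d << (SHIFT * i) for i, d in enumerate(digits))

def rshift_digits(digits, shift):
    """Replay long_rshift1's loop for a non-negative number."""
    wordshift, remshift = divmod(shift, SHIFT)  # the job divmod_shift does
    newsize = len(digits) - wordshift
    if newsize <= 0:
        return []                               # result is 0
    hishift = SHIFT - remshift
    lomask = (1 << hishift) - 1
    himask = MASK ^ lomask
    z = [0] * newsize                           # stands in for _PyLong_New(newsize)
    for i in range(newsize):
        j = wordshift + i
        z[i] = (digits[j] >> remshift) & lomask
        if i + 1 < newsize:
            z[i] |= (digits[j + 1] << hishift) & himask
    while z and z[-1] == 0:                     # long_normalize: drop zero high digits
        z.pop()
    return z

def rshift_int(n, shift):
    """n >> shift via the digit loop; negatives use long_rshift1's invert trick."""
    if shift < 0:
        raise ValueError("negative shift count")
    if n < 0:
        return ~rshift_int(~n, shift)           # ~n >= 0 when n < 0
    return join(rshift_digits(split(n), shift))

print(rshift_digits([4, 2, 1], 1))  # [2, 536870913]
```

The negative branch relies on the identity n >> k == ~(~n >> k), which holds for Python's floor-rounding shifts. It is the same trick long_rshift1 implements with long_invert.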

Conclusion

That brings this article to a close. The analysis turned up some interesting, if obscure, knowledge. What we observed is a deliberate optimization that CPython makes for common, high-frequency operations such as multiplying small integers. It is also a reminder that Python has its own design choices for many operations. Intuitions from other languages, such as "shifts are always cheaper than multiplication", may not carry over.

That's about it. Filler articles are all I can barely write to survive (runs away).

Reference
