`
yangdong
  • 浏览: 65000 次
  • 性别: Icon_minigender_1
  • 来自: 杭州
社区版块
存档分类
最新评论

TimSort 中的核心过程

 
阅读更多
    TimSort 是 Python 中 list.sort 的默认实现。Java 7 也将非原始类型列表的排序实现替换成了 TimSort。网上关于 TimSort 是什么,性能特点分析的文章不少,但是介绍它的具体实现步骤的文章很少。这里有一篇:Understanding timsort, Part 1: Adaptive Mergesort,用 C 作为示例代码。

基于这个文章的介绍,我用 python 实现一遍 TimSort,并说一下其中的关键步骤。因为原文只讲解了 TimSort 中最基本最重要的部分,所以本文也没有超过这个范围。本文不是对 TimSort 的分析,只是介绍一下其基本实现。

TimSort 概览
    TimSort 是一个归并排序做了大量优化的版本。对归并排序排在已经反向排好序的输入时表现O(n^2)的特点做了特别优化。对已经正向排好序的输入减少回溯。对两种情况混合(一会升序,一会降序)的输入处理比较好。

TimSort 核心过程
    假定,我们的 TimSort 是进行升序排序。TimSort 为了减少对升序部分的回溯和对降序部分的性能倒退,将输入按其升序和降序特点进行了分区。排序的输入的单位不是一个个单独的数字了,而一个个的分区。其中每一个分区我们叫一个“run“。针对这个 run 序列,每次我们拿一个 run 出来进行归并。每次归并会将两个 runs 合并成一个 run。归并的结果保存到 "run_stack" 上。如果我们觉得有必要归并了,那么进行归并,直到消耗掉所有的 runs。这时将 run_stack 上剩余的 runs 归并到只剩一个 run 为止。这时这个仅剩的 run 即为我们需要的排好序的结果。

def timsort(arr):
    arr = arr or []
    if len(arr) <= 0: return []
    runs = _partition_to_runs(arr)
    run_stack = []
    for run in runs:
        run_stack.append(run)
        while _should_merge(run_stack):
            _merge_stack(run_stack)
    while len(run_stack) > 1:
        _merge_stack(run_stack)
    return run_stack[0]


这里“觉得有必要”这句话很模糊,到底什么时候有必要后面会给出定义。

如何分区
    为了在已经按升序排好序的输入面前减少回溯,我们把输入当中已经有序的这些段分组,使得它们成为一个基本单元,这样我们就不必在这个基本单元内部浪费时间进行回溯了。比如[1, 2, 3, 2] 进行分区后就变成了 [[1, 2, 3], [2]]。

为了在已经按降序排好序的输入面前避免归并排序倒退成 O(n^2),我们把输入当中降序的部分翻转成升序,也作为一个单元。比如 [3, 2, 1, 3] 进行分区后就变成了 [[1, 2, 3], [3]]。

def _partition_to_runs(arr):
    partitioned_up_to = 0
    while partitioned_up_to < len(arr):
        if not len(arr) - partitioned_up_to:
            return
        if len(arr) - partitioned_up_to == 1:
            part = list(arr[-1:])
            partitioned_up_to += 1
            yield part
        else:
            if arr[partitioned_up_to] > arr[partitioned_up_to + 1]: # 这里必须是严格降序
                next_pos = _find_desc_boundary(arr, partitioned_up_to)
                _reverse(arr, partitioned_up_to, next_pos)
            else:
                next_pos = _find_asc_boundary(arr, partitioned_up_to)

            part = arr[partitioned_up_to:next_pos]
            partitioned_up_to = next_pos
            yield part

def _find_desc_boundary(arr, start):
    if start >= len(arr) - 1:
        return start + 1
    if arr[start] > arr[start+1]: # 这里必须是严格降序
        return _find_desc_boundary(arr, start + 1)
    else:
        return start + 1

def _reverse(arr, start=0, end=None):
    # 正常的翻转函数,实现省略

def _find_asc_boundary(arr, start):
    if start >= len(arr) - 1:
        return start + 1
    if arr[start] <= arr[start+1]:
        return _find_asc_boundary(arr, start + 1)
    else:
        return start + 1


这里注意降序的部分必须是“严格”降序才能进行翻转。因为 TimSort 的一个重要目标是保持稳定性(stability)。如果在 >= 的情况下进行翻转这个算法就不再是 stable sorting algorithm 了。

逆向分解
    传统的归并排序是通过递归,用函数栈把每次 "divide" 的结果保存下来的。divide 的最终结果是一个个的基本单元-单个数字。但是我们看到 TimSort 把这个过程反过来了。我们经过一次分区,已经拿到了了基本单元列表,只不过这次基本单元是一串数字。所以我们只能自己手工将将基本单元列表进行合并。

如何合并
    那么何时进行合并?合并的策略是要在 "run_stack" 上维护一个不变式。当这个不变式被打破时即进行合并。传统的归并排序通过二分法可以保证函数栈的深度为 log(n)。我们也模拟这个策略,也让 run_stack 的长度不超过 log(n)。假如 runN 先入栈,runN+1 紧随其后入栈。那么就要求 runN 的长度要是 runN+1 长度的 2 倍。所以归并的条件是:如果 runN 的长度 < (runN+1 的长度 * 2) 即进行归并。

# 因为我们每次新添 run 进入 run_stack 时都判断是否需要归并,
# 并且在每次归并之后还要进一步确保 run_stack 是满足不变式的,
# 所以这里只判断栈头的两个 run 就够了。
def _should_merge(run_stack):
    if len(run_stack) < 2:
        return False
    return len(run_stack[-2]) < 2*len(run_stack[-1])

def _merge(ls1, ls2):
    # 正常的归并函数,实现省略

def _merge_stack(run_stack):
    head = run_stack.pop()
    next = run_stack.pop()
    new_run = _merge(next, head)
    run_stack.append(new_run)


跟分区的情况类似,这里在归并的时候也要用 stable merge。

插入排序优化
    到上面的步骤为止,程序已经可以正确地排序了。但是我们知道插入排序在输入元素数小于一个阀值的时候相比其它排序会更快,所以很多排序算法在 divide 这一步进行到只剩不到这个阀值个数的元素的时候会改用插入排序(比如 JDK6 的快排,参考这里),所以我们也要做这个优化。

在分区的时候,如果我们观察到新产生出来的 run 的长度小于适用于插入排序的阀值,我们就用插入排序把这个 run 的长度扩充到这个阀值。

def _partition_to_runs(arr):
    partitioned_up_to = 0
    while partitioned_up_to < len(arr):
        if not len(arr) - partitioned_up_to:
            return
        if len(arr) - partitioned_up_to == 1:
            part = list(arr[-1:])
            partitioned_up_to += 1
            yield part
        else:
            if arr[partitioned_up_to] > arr[partitioned_up_to + 1]:
                next_pos = _find_desc_boundary(arr, partitioned_up_to)
                _reverse(arr, partitioned_up_to, next_pos)
            else:
                next_pos = _find_asc_boundary(arr, partitioned_up_to)

            # 只加了这一句话
            next_pos = _do_insertion_sort_optimization(arr, partitioned_up_to, next_pos)

            part = arr[partitioned_up_to:next_pos]
            partitioned_up_to = next_pos
            yield part

def _insertion_sort(arr, start, end):
    # 标准插入排序实现

def _do_insertion_sort_optimization(arr, start, end):
    length = end - start
    if length < INSERTION_SORT_THRESHOLD:
        end = min(start+INSERTION_SORT_THRESHOLD, len(arr))
        _insertion_sort(arr, start, end)
    return end


这里我们只加一句话就够了。剩余的就是标准的插入排序实现。

与原文代码的差异
    TimSort 最多使用 O(n) 临时内存空间。由于原文是 C 的代码,为了减少 malloc 的次数而一次性分配了 O(n) 的数组空间。我们这里因为是用 python,也这么做会显得很怪异。所以内存是在每次归并的时候一点点分配的。

TimSort 的实现逻辑上可以看成分区和归并两部分。但由于 C 不支持协程,而 python 通过 generator 部分支持协程。所以为了提高可读性,分区的部分我是用 generator 的方式做的。在代码上与归并的部分完全分离。而原文为了达到 lazy 的目的,是一边分区一边归并的。

完整的实现和测试代码
# -*- coding: utf-8 -*-
import functools
from unittest import TestCase

INSERTION_SORT_THRESHOLD = 6

def _find_desc_boundary(arr, start):
    if start >= len(arr) - 1:
        return start + 1
    if arr[start] > arr[start+1]:
        return _find_desc_boundary(arr, start + 1)
    else:
        return start + 1

def _reverse(arr, start=0, end=None):
    if end is None:
        end = len(arr)
    for i in range(start, start + (end-start)//2):
        opposite = end - i - 1
        arr[i], arr[opposite] = arr[opposite], arr[i]

def _find_asc_boundary(arr, start):
    if start >= len(arr) - 1:
        return start + 1
    if arr[start] <= arr[start+1]:
        return _find_asc_boundary(arr, start + 1)
    else:
        return start + 1

def _insertion_sort(arr, start, end):
    if end - start <= 1:
        return
    for i in range(start, end):
        v = arr[i]
        j = i - 1
        while j>=0 and arr[j] > v:
            arr[j+1] = arr[j]
            j -= 1
        arr[j+1] = v

def _do_insertion_sort_optimization(arr, start, end):
    length = end - start
    if length < INSERTION_SORT_THRESHOLD:
        end = min(start+INSERTION_SORT_THRESHOLD, len(arr))
        _insertion_sort(arr, start, end)
    return end

def _partition_to_runs(arr):
    partitioned_up_to = 0
    while partitioned_up_to < len(arr):
        if not len(arr) - partitioned_up_to:
            return
        if len(arr) - partitioned_up_to == 1:
            part = list(arr[-1:])
            partitioned_up_to += 1
            yield part
        else:
            if arr[partitioned_up_to] > arr[partitioned_up_to + 1]:
                next_pos = _find_desc_boundary(arr, partitioned_up_to)
                _reverse(arr, partitioned_up_to, next_pos)
            else:
                next_pos = _find_asc_boundary(arr, partitioned_up_to)

            next_pos = _do_insertion_sort_optimization(arr, partitioned_up_to, next_pos)

            part = arr[partitioned_up_to:next_pos]
            partitioned_up_to = next_pos
            yield part

def _should_merge(run_stack):
    if len(run_stack) < 2:
        return False
    return len(run_stack[-2]) < 2*len(run_stack[-1])

def _merge(ls1, ls2, merge_storage=None):
    ret = merge_storage or []
    i1 = 0
    i2 = 0
    while i1 < len(ls1) and i2 < len(ls2):
        a = ls1[i1]
        b = ls2[i2]
        if a <= b:
            ret.append(a)
            i1 += 1
        else:
            ret.append(b)
            i2 += 1
    ret += ls1[i1:]
    ret += ls2[i2:]
    return ret

def _merge_stack(run_stack, merge_storage=None):
    head = run_stack.pop()
    next = run_stack.pop()
    new_run = _merge(next, head, merge_storage=merge_storage)
    run_stack.append(new_run)

def timsort(arr):
    arr = arr or []
    if len(arr) <= 0: return []
    runs = _partition_to_runs(arr)
    run_stack = []
    for run in runs:
        run_stack.append(run)
        while _should_merge(run_stack):
            _merge_stack(run_stack)
    while len(run_stack) > 1:
        _merge_stack(run_stack)
    return run_stack[0]

class Test(TestCase):
    class Elem:
        seq_no = 0
        def __init__(self, n):
            Elem = Test.Elem
            self.n = n
            self.seq_no = Elem.seq_no
            Elem.seq_no += 1

        def __lt__(self, other):
            return self.n < other.n

        def __str__(self):
            return "E" + str(self.n) + "S" + str(self.seq_no)
    Elem = functools.total_ordering(Elem)

    def setUp(self):
        Test.Elem.seq_no = 0

    def test_reverse(self):
        arr = [3, 2, 1, 4, 7, 5, 6]
        _reverse(arr)
        self.assertEquals(arr, [6, 5, 7, 4, 1, 2, 3])

        arr = [3, 2, 1]
        _reverse(arr)
        self.assertEquals(arr, [1, 2, 3])

    def test_find_asc_boundary(self):
        arr = [1, 2, 3, 3, 2]
        self.assertEqual(_find_asc_boundary(arr, 0), 4)

        arr = [1, 2, 3, 3]
        self.assertEqual(_find_asc_boundary(arr, 0), 4)

    def test_find_desc_boundary(self):
        arr = [3, 2, 1]
        self.assertEqual(_find_desc_boundary(arr, 0), 3)

        arr = [3, 2, 1, 1]
        self.assertEqual(_find_desc_boundary(arr, 0), 3)

    def test_merge_stack(self):
        arr1 = [1, 2, 3]
        arr2 = [2, 3, 4]
        stack = [arr1, arr2]
        _merge_stack(stack)
        self.assertEqual(stack, [[1, 2, 2, 3, 3, 4]])

    def test_merge_stability(self):
        Elem = Test.Elem
        arr1 = map(lambda e: Elem(e), [1, 2, 3])
        arr2 = map(lambda e: Elem(e), [2, 3, 4])
        stack = [arr1, arr2]
        _merge_stack(stack)
        self.assertEqual(map(lambda lst: map(str, lst), stack), [['E1S0', 'E2S1', 'E2S3', 'E3S2', 'E3S4', 'E4S5']])

    def test_timsort(self):
        Elem = Test.Elem
        arr = map(lambda e: Elem(e), [3, 1, 2, 2, 7, 5])
        ret = timsort(arr)
        self.assertEquals(map(str, ret), ['E1S1', 'E2S2', 'E2S3', 'E3S0', 'E5S5', 'E7S4'])

        self.assertEqual(timsort([]), [])
        self.assertEqual(timsort(None), [])
分享到:
评论

相关推荐

    TimSort:JSES中的TimSort实现

    TimSort在JS / ES中的实现。 专为教育目的而设计。 什么是TimSort 是一种,是从和派生而来的,旨在对多种现实世界数据表现良好。 它使用了Peter McIlroy的“乐观排序和信息理论的复杂性”技术,该技术在1993年1月...

    timsort 算法

    jdk中collections包中用的排序算法,算法以发明者tim命名

    cpp-TimSort:timsort的C ++实现

    cpp-TimSort:timsort的C ++实现

    java-timsort-bug:如何破坏 TimSort 以及如何修复它

    时间排序错误如何破坏 TimSort 以及如何修复它

    Timsort For Fortran:通过 ISO_C_BINDING 在 Fortran 中使用 timsort.c-开源

    我发现我需要在 Fortran 中使用 Timsort,但找不到 Fortran 实现。 我确实找到了 C 和 C++ 中的实现。 我的第一个努力是(痛苦地)将 C++ 代码翻译成 Fortran。 然后,我决定从 C 代码翻译可能更容易一些。 然后我...

    深入探究TimSort对归并排序算法的优化及Java实现

    主要介绍了TimSort归并排序的优化及Java实现,TimSort 是一个归并排序做了大量优化的版本,需要的朋友可以参考下

    TimSort-最快的排序算法-Python实现

    python实现最快的排序算法timsort实例,附有大量注释,易懂【手动狗头】。

    algorithms-and-data-structures:用 C# 编写的各种算法和数据结构的集合

    用C#制作的各种算法和数据结构的集合目前拥有:算法计算几何Convex hull — 使用 Graham 扫描找到包含集合中所有点的最小多边形图表最小割——计算最轻的一组边,如果它们被删除,将把图分成两个部分最小生成树 — ...

    sort:对“模板” C中的例程实现进行排序

    排序概述sort.h是C语言中大量排序算法的实现,具有包含时间提供的用户定义类型。 这意味着您不必支付使用标准库例程的函数调用开销。 这也为我们提供了高级语言泛型的功能。 另外,您不必链接库:此排序库的全部包含...

    leetcode凑硬币-Arithmatic:算术

    概要的讲解timsort的实现以及timsort的bugs,因为是视频,所以相比论文我觉得更快看得懂,没字幕,听不懂怎么办,没事,演讲者有一个文章重新梳理视频内容 2,Tim peters自己写的论文 二维“有序数组查找” —...

    sort-test-js

    排序比较的图形化表示JS中的TimSort(本机),QuickSort(本机),WikiSort(自定义),lodash.sortBy 发射 npm i npm run test npm start 要检查TimSort,请使用V8&gt; = 7版本(例如12 node.js) 要测试QuickSort...

    C语言演示对归并排序算法的优化实现

    子数组的排序同样采用这样的方法排序,这个过程是递归的。 下面是示例代码: #include timsort.h #include #include // 将两个长度分别为l1, l2的已排序数组p1, p2合并为一个 // 已排序的目标数组。 void merge...

    SortPerf.jl:Julia模块测试排序算法的性能

    该模块的目的是测试Julia中不同种类(及相关)算法的性能。 有关版本0.3.0-prerelease + 125的示例输出,请参见 。 运行: std_sort_tests(;sort_algs=SortPerf.sort_algs, # [InsertionSort, HeapSort, ...

    java-server-interview-questions:java服务端面试题整理

    TimSort算法就是找到已经排好序数据的子序列,然后对剩余部分排序,然后合并起来. foreach和while的区别(编译之后) 线程池的种类,区别和使用场景 分析线程池的实现原理和线程的调度过程 线程池如何调优 线程池的...

    super-simple-sortr

    这是我从事的一个小项目,目的是尝试增强我在不同分类算法中的知识。 制作真的很有趣,希望您喜欢。 该应用程序包含12种不同的排序算法,这些算法可视化为: 气泡排序 BogoSort 鸡尾酒摇床排序 梳理排序 计数...

    java-GenericSort:Java中使用泛型实现的各种排序算法

    通用排序Java中使用泛型实现的各种排序算法注意:这是一个有趣的项目。 如果您希望熟悉排序算法或Java... 如果要在实际项目中执行排序,建议您使用Java的内置Collections.sort(),它使用高效的TimSort。 请参阅 。

    barsort:一种非常稳定的javascript数字排序

    Barsort利用类似于“计数排序”的特殊算法,该算法用于将数组元素放入大小相似,大小相等的组中。 它在这里与经过调整的插入和合并排序以及边缘案例处理相结合,以创建快速的数字排序。 在各种可能的输入分布和大小...

    algOrd-compare:马德里TP de PAA的商业比较

    简介索泰trabalhoserãorealizadascomparações德desempenho恩特雷里奥斯OS algoritmos MergeSort , InsertionSort ē TimSort considerandoØ节奏EONÚMERO德comparaçõesfeitas对ordenar completamente OS ...

    Algorithms:我实现的所有算法均根据其目的分类

    演算法我实现的所有算法均对它们的用途进行了分类。内容您可以找到以下目录: 代数:与代数有关的算法(找到素数,最大公约数) ... SortingAlgorithms :一些排序算法和变体(快速排序,堆排序,合并排序,timsort)。

    simple-algorithms:算法评论

    简单算法 关于 简短的算法回顾使用Java,包括在最佳和最差情况下每种算法的步骤数,时间复杂度和运行时间。 演算法 快速排序 合并排序 堆排序 气泡排序 插入排序 选择排序 基数排序 Timsort

Global site tag (gtag.js) - Google Analytics