【NumPy入門 np.sort】配列を昇順にソートする方法について学ぼう

こんにちは、インストラクターのフクロウです!

要素を小さい順に並べる(ソートする)アルゴリズムはたくさんありますが、それらの実装はちょっと大変です。

NumPyではソートアルゴリズムが簡単に使える関数、np.sortが用意されています。

np.sortを使えば、クイックソートマージソートなどのソートアルゴリズムを簡単に切り替えることができるので、アルゴリズムを気にせずにソートができます。

この記事でnp.sortの使い方を覚えて、是非使ってみてください!

目次

np.sortの使い方

※この記事のコードは、jupyter notebookjuputer labを使って書かれています。
コードを試すときは是非これらを使ってみてください。

# コード In [1]:
import numpy as np

一次元配列のソート

np.sort関数の基本的な使い方は、Pythonのsortedと同様です。

np.sortを適用すると、昇順にソートされた配列が返ってきます。

# コード In [2]:
a = np.arange(0,10)
np.random.shuffle(a)
print("original a:n", a,"n")

sorted_a = np.sort(a)
print("sorted a:n", sorted_a,"n")
# 出力結果 [2]:
original a:
 [1 5 6 7 2 8 9 4 0 3] 

sorted a:
 [0 1 2 3 4 5 6 7 8 9] 

逆順にソート

reversedのような逆順にソートする関数はありませんが、numpy.ndarrayを逆にする方法はあります。

●スライス記法で逆順にする

配列aを逆順にしたいなら→ a[::-1]

これをsortと併用することで逆順にソートした結果を得れます。試してみましょう。

# コード In [3]:
reversed_a = np.sort(a)[::-1]
print("reversed a:n", reversed_a)
# 出力結果 [3]:
reversed a:
 [9 8 7 6 5 4 3 2 1 0]

多次元配列のソート

np.sumなどと同様に、axisパラメータを使ってソートする際に従う軸を指定できます。

# コード In [4]:
b = np.reshape(a, (2,5))
print("original b:n", b,"n")

# 最後の軸に従ってソート(axis=-1, つまりここではaxis=1)
sorted_b = np.sort(b)
print("sorted b:n", sorted_b,"n")

# 列ごとにソート
sorted_b_axis0 = np.sort(b, axis=0)
print("sorted b (axis=0):n", sorted_b_axis0,"n")

# 行ごとにソート
sorted_b_axis1 = np.sort(b, axis=1)
print("sorted b (axis=1):n", sorted_b_axis1,"n")
# 出力結果 [4]:
original b:
 [[1 5 6 7 2]
 [8 9 4 0 3]] 

sorted b:
 [[1 2 5 6 7]
 [0 3 4 8 9]] 

sorted b (axis=0):
 [[1 5 4 0 2]
 [8 9 6 7 3]] 

sorted b (axis=1):
 [[1 2 5 6 7]
 [0 3 4 8 9]] 

ソートアルゴリズムの変更

np.sortにはkindパラメータがあり、ここでは‘quicksort’, ‘mergesort’, ‘heapsort’, ‘stable’の四種類のソートアルゴリズムを指定できます。

デフォルトではkind=‘quicksort’となっているので、クイックソート以外を使いたい場合はこのパラメータを変更しましょう。

# コード In [5]:
x = np.linspace(-9,10,10000)

# ソートアルゴリズムによって実行速度が違います。
%timeit np.sort(x, kind="quicksort")
%timeit np.sort(x, kind="mergesort")
%timeit np.sort(x, kind="heapsort")
%timeit np.sort(x, kind="stable")
# 出力結果 [5]:
67.4 µs ± 136 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
78 µs ± 381 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
429 µs ± 240 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
78.1 µs ± 176 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

任意の指標を使ってソート

pythonのsortメソッドやsorted関数にはkeyオプションがありました。

これはkeyに指定した関数に従ってソートするというものです。

これに近い方法をnumpyで行うにはどうすればいいでしょうか。

絶対値を使って昇順ソートをしたいとしましょう。

numpy.ndarrayの全ての要素に対して適用できる関数でソートの基準を再現できるなら、以下の流れでソートできます。

  1. ソートしたい配列cに、ソートの基準にしたい関数を適用
  2. 1でできた配列をargsort
  3. 2でできた配列を使ってcを並び替える
# コード In [6]:
c = np.linspace(-9,10,10)
print("original c:n", c,"n")

c_key = np.abs(c)
print("元の配列にソートの基準にしたい関数を適用:n", c_key,"n")

c_index = np.argsort(c_key)
print("c_keyをargsort:n", c_index, "n")

sorted_c = c[c_index]
print("c_indexを元にソートしたい配列cを並び替える:n", sorted_c,"n")
# 出力結果 [6]:
original c:
 [-9.         -6.88888889 -4.77777778 -2.66666667 -0.55555556  1.55555556
  3.66666667  5.77777778  7.88888889 10.        ] 

元の配列にソートの基準にしたい関数を適用:
 [ 9.          6.88888889  4.77777778  2.66666667  0.55555556  1.55555556
  3.66666667  5.77777778  7.88888889 10.        ] 

c_keyをargsort:
 [4 5 3 6 2 7 1 8 0 9] 

c_indexを元にソートしたい配列cを並び替える:
 [-0.55555556  1.55555556 -2.66666667  3.66666667 -4.77777778  5.77777778
 -6.88888889  7.88888889 -9.         10.        ] 

一つの関数にまとめると以下のような感じです。

# コード In [7]:
def keysort(arr, key):
    arr_key = key(arr)
    arr_index = np.argsort(arr_key)
    return arr[arr_index]

sorted_abs = keysort(c, np.abs)
print("絶対値を元にソート:n", sorted_abs, "n")
# 出力結果 [7]:
絶対値を元にソート:
 [-0.55555556  1.55555556 -2.66666667  3.66666667 -4.77777778  5.77777778
 -6.88888889  7.88888889 -9.         10.        ] 

どうにかnumpyの関数だけで目的のソート方法が再現できるならば、このような方法も使えます。

どうしても難しい場合は、pythonのsorted関数を使ってしまいましょう。

# コード In [8]:
sorted(c, key=abs)
# 出力結果 Out [8]:
[-0.5555555555555554,
 1.5555555555555554,
 -2.666666666666666,
 3.666666666666668,
 -4.777777777777778,
 5.777777777777779,
 -6.888888888888889,
 7.888888888888889,
 -9.0,
 10.0]

ただし、pythonのsortedはnumpyだけで実装したkeysortより遅いので注意です。

# コード In [9]:
x = np.linspace(-9,10,10000)

%timeit keysort(x, np.abs)
%timeit sorted(x, key=abs)
# 出力結果 [9]:
483 µs ± 225 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.13 ms ± 12.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

まとめ

この記事では、np.sortを中心に、np.ndarrayをソートする方法を解説しました。

ソートアルゴリズムは実装によっては速度が大幅に変わることのあるものなので、NumPyに組み込まれているnp.sortを使って楽に高速な関数を使うのがおすすめです。

様々なところで必要になるソート、簡単ですがしっかり覚えて使いこなしましょう。

この記事を書いた人

第一言語はPythonです。
皆さんRustやりましょう。

目次