【NumPy入門 np.argsort】配列をソートしてインデックスを返す関数

こんにちは、インストラクターのフクロウです!

NumPyには、配列をソートする関数のnp.sortがあります。

これに対して、ソートした配列の要素のインデックスを返す関数np.argsortです。

この記事では、np.argsort関数の基本的な使い方から、多次元配列に対する使い方までをご説明します!

目次

np.argsortで一次元配列をソートする

この記事で使っているf-stringsという機能は比較的新しい書き方なので、気になった方は以下の記事で確認して下さいね。

さて、まずは簡単な使い方を確認しましょう。

使い方はnp.sortと同様です!

import numpy as np
a = np.random.random(5)

print(
    f"""Original a
    {a}
    
    np.sort(a)
    {np.sort(a)}
    
    np.argsort(a)
    {np.argsort(a)}
    """
)

[出力結果]

Original a
    [0.4142068  0.5381178  0.53869381 0.16719389 0.14693401]
    
    np.sort(a)
    [0.14693401 0.16719389 0.4142068  0.5381178  0.53869381]
    
    np.argsort(a)
    [4 3 0 1 2]

np.sort関数の出力値を確認したら、各要素が元の配列のどのIndexだったかを確認してみてください!

しっかりnp.argsortの出力と同じになっているはずです。

np.argsortで多次元配列をソートする

基本的な使い方

さて、argsortは出力がindexになっているだけで、多次元配列でも使い方はnp.sortと同じです。

どのような動作をするか見てみましょう。

b = np.random.random((4,5))

print(
    f"""Original b
    {b}
    
    np.sort(b)
    {np.sort(b)}
    
    np.argsort(b)
    {np.argsort(b)}
    """
)

[出力結果]

Original b
    [[0.97935047 0.90121107 0.80337142 0.18207252 0.23487007]
 [0.70830563 0.18397233 0.06510502 0.12959257 0.54751296]
 [0.40600083 0.12071886 0.03439174 0.82277469 0.3728914 ]
 [0.17822189 0.33712262 0.487275   0.89956249 0.95184833]]
    
    np.sort(b)
    [[0.18207252 0.23487007 0.80337142 0.90121107 0.97935047]
 [0.06510502 0.12959257 0.18397233 0.54751296 0.70830563]
 [0.03439174 0.12071886 0.3728914  0.40600083 0.82277469]
 [0.17822189 0.33712262 0.487275   0.89956249 0.95184833]]
    
    np.argsort(b)
    [[3 4 2 1 0]
 [2 3 1 4 0]
 [2 1 4 0 3]
 [0 1 2 3 4]]

ソートしたい配列だけを引数にすると、各行の中でソートを行っていることがわかりますね。

print(
    f"""Original b
    {b}
    
    np.sort(b)
    {np.sort(b, axis=0)}
    
    np.argsort(b)
    {np.argsort(b, axis=0)}
    """
)

[出力結果]

Original b
    [[0.97935047 0.90121107 0.80337142 0.18207252 0.23487007]
 [0.70830563 0.18397233 0.06510502 0.12959257 0.54751296]
 [0.40600083 0.12071886 0.03439174 0.82277469 0.3728914 ]
 [0.17822189 0.33712262 0.487275   0.89956249 0.95184833]]
    
    np.sort(b)
    [[0.17822189 0.12071886 0.03439174 0.12959257 0.23487007]
 [0.40600083 0.18397233 0.06510502 0.18207252 0.3728914 ]
 [0.70830563 0.33712262 0.487275   0.82277469 0.54751296]
 [0.97935047 0.90121107 0.80337142 0.89956249 0.95184833]]
    
    np.argsort(b)
    [[3 2 2 1 0]
 [2 1 1 0 2]
 [1 3 3 2 1]
 [0 0 0 3 3]]

axis=1

axis=1では、行ごとに見て要素をソートすることができます。

print(
    f"""Original b
    {b}
    
    np.sort(b)
    {np.sort(b, axis=1)}
    
    np.argsort(b)
    {np.argsort(b, axis=1)}
    """
)

[出力結果]

Original b
    [[0.97657721 0.36010893 0.14086801]
 [0.09089705 0.92285532 0.27189188]]
    
    np.sort(b)
    [[0.14086801 0.36010893 0.97657721]
 [0.09089705 0.27189188 0.92285532]]
    
    np.argsort(b)
    [[2 1 0]
 [0 2 1]]

これらを使うことで、より柔軟なソートができますね。

まとめ

この記事では、配列のソートしたインデックスを返す関数であるnp.argsortを紹介しました。

最大値や最小値のみを取り出す機能はありましたが、np.argsortを使えば

  • 上位N個の要素を取り出す
  • 下位N個の要素を取り出す

などの操作も簡単にできます。

ぜひ覚えて便利に使いこなしてください!

この記事を書いた人

第一言語はPythonです。
皆さんRustやりましょう。

目次