seaborn heatmap でx,y 軸のラベルが出したかったお話

sns_heatmap GPU
sns_heatmap labels

pythonのnumpyを使って作成した二次元の表をseaborn のheatmapを使ってヒートマップ化したときに、

matplotlib.pyplotのxlabel({x軸のラベル名})、ylabel({y軸のラベル名})みたいに軸の属性を表記したかったけど、そのやり方が少しだけ詰まったのでメモ。(ちなみに、heatmapの引数のxticklabelsやyticklabelsは少し用途が違う。)

状況

画像中はイメージだが、csvなどからピポッドテーブルをつくってseabornをつかってヒートマップをつくったら自動的にx,y軸の属性名、属性値が割り振られる。ここで、属性名がmonth,year、属性値をそれぞれJanuary,February,…, や1949,1950,…のこととする。

numpyとかで2次元の混合行列(Confusion Matrix)などを出力して画像化したいときにこれを作るのがすこし詰まったという状況。

ちなみに、上の画像のために用意したデータセットは

import seaborn as sns
flights = sns.load_dataset("flights")
flights = flights.pivot("month", "year", "passengers") 

で用意しました。

解決策

sns.heatmapAPIリファレンスをみてみると

Returns : ax : matplotlib Axes

Axes object with the heatmap.

https://seaborn.pydata.org/generated/seaborn.heatmap.html#seaborn.heatmap

ということで、裏で動いてるのはすなわちmatplotlibであるということなので、matplotlib.pyplotxlabelylabelで解決するのではないかと思って試したらうまく行った。

実際に動かしているコードを下に載せる。

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
flights = sns.load_dataset("flights")
 
# ピボットを生成
flights = flights.pivot("month", "year", "passengers")
#データテーブルから2次元のnumpy配列に変換
np_flight=np.array(flights.values) 
#配列をヒートマップ化
sns.heatmap(np_flight)
#*以下2行がポイント*  X,Y軸ラベルを追加
plt.xlabel("xxxxxxxx")
plt.ylabel("yyyyyyyy")
#グラフをはみ出さないようにして画面に出力
plt.tight_layout()
plt.show()
#もし保存したいときはplt.show() をコメントアウトなどして、以下のコメントを外してね
#plt.savefig("flight_heatmap.png")

出力結果は以下のようになる

また、x,y軸の0~11ってとこになにか文字を入れたい場合はsns.heatmap(xticlabels=[X軸の方のリスト],yticlabels=[Y軸の方のリスト])というふうに引数にリストを渡して表示させる方法があります。

コード全体は以下の通り(14,15行目と18行目が対応点です。)

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
flights = sns.load_dataset("flights")
 
# ピボットを生成
flights = flights.pivot("month", "year", "passengers")
#データテーブルから2次元のnumpy配列に変換
np_flight=np.array(flights.values) 
#難しい感じで書いてますが
#[X_a,X_b,...,]みたいなリストを作ってるだけです
xtics=["X_"+chr(i) for i in range(97,97+12)]
ytics=["Y_"+chr(i) for i in range(97,97+12)]
#配列をヒートマップ化 + XYのカラム名を表示
sns.heatmap(np_flight,xticklabels=xtics,yticklabels=ytics)
#*以下2行がポイント*  X,Y軸ラベルを追加
plt.xlabel("xxxxxxxx")
plt.ylabel("yyyyyyyy")
#グラフをはみ出さないようにして画面に出力
plt.tight_layout()
plt.show()
#もし保存したいときはplt.show() をコメントアウトなどして、以下のコメントを外してね
#plt.savefig("./Pictures/flight_heatmap.png")

出力結果は以下の通り

これで、自分の思う通りに出力できました。ちなみに、色とかも変えれますのでお試しあれ。

結論

sns.seabornのオプションだけでは思うように出力できませんでした。

plt.xlabel(),plt.ylabel()を使って対応しなくてはならなかった。

追伸:他にやり方があればぜひおしえてください

GPULinux Ubuntunumpytensorflowサーバー機械学習深層学習
スポンサーリンク
Toufuをフォローする

コメント

  1. Hi there, of course this article is truly good and I have learned lot of things from it about blogging.

    thanks.

  2. buy CBD oil より:

    I am really happy to read this web site posts which includes plenty of useful facts, thanks for providing such data.

タイトルとURLをコピーしました