- Matplotlibを用いてGUI上でフィッティングを実施。
- 選択範囲を選ぶと結果をリアルタイムにプロットできます。
- 意外とweb上に見当たらなかったので作りました。
コード例
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.widgets import SpanSelector
from lmfit import Model
def gaussian(x, amp, cen, wid):
return (amp / (np.sqrt(2*np.pi) * wid)) * np.exp(-(x-cen)**2 / (2*wid**2))
def fit_gaussian(x, data):
mod = Model(gaussian)
pars = mod.make_params(amp=1, cen=1, wid=1)
result = mod.fit(data, pars, x=x)
return result
class Fit:
def __init__(self, ax, df):
self.ax = ax or plt.gca()
self.df = df
def fitting(self, x1, x2):
xx = self.df.query(f'{x1}<x<{x2}').x
yy= self.df.query(f'{x1}<x<{x2}').y
result = fit_gaussian(xx, yy)
print("-------------Fitting result------------------")
pars = list(result.best_values.values())
print(pars)
fitted.set_data(self.df.x, gaussian(self.df.x, *pars) )
residual.set_data(self.df.x, self.df.y - gaussian(self.df.x, *pars) )
fig.canvas.draw()
fig.canvas.flush_events()
def select_callback(self, x1, x2):
self.fitting(x1, x2)
# デモデータ
data = np.random.randn(10000)
histo,bins = np.histogram(data,range=(-5,5),bins=100,density=True)
x=bins[1:]
df = pd.DataFrame(data=np.stack([x,histo]).T, columns=['x', 'y'])
# プロット
fig = plt.figure(constrained_layout=True)
ax = fig.add_subplot(111)
ax.plot(df.x, df.y, ".", label="data")
ax.set_xlabel("x")
ax.set_ylabel("y (a.u.)")
# あとで更新する
fitted, = ax.plot([],[], c="r", label = "fitted")
residual, = ax.plot([], [], drawstyle = "steps-post", label = "residual")
ax.legend()
fit = Fit(ax, df)
# x軸を選ぶ
span = SpanSelector(
ax,
fit.select_callback,
"horizontal",
useblit=True,
props=dict(alpha=0.2, facecolor="tab:blue"),
interactive=True,
drag_from_anywhere=True
)
plt.show()
ポイント解説
ライブラリ
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.widgets import SpanSelector
from lmfit import Model
def gaussian(x, amp, cen, wid):
return (amp / (np.sqrt(2*np.pi) * wid)) * np.exp(-(x-cen)**2 / (2*wid**2))
def fit_gaussian(x, data):
mod = Model(gaussian)
pars = mod.make_params(amp=1, cen=1, wid=1)
result = mod.fit(data, pars, x=x)
return result
「from matplotlib.widgets import SpanSelector」がmatplotlibのGUI操作に必要なモジュールです。
「from lmfit import Model」がフィッティングに使用するモジュールです。
直後にガウシアン関数とその関数をもとにフィッティングする関数を定義しています。
更新用のplt
# あとで更新する
fitted, = ax.plot([], [], c="r", label = "fitted")
フィッティングによりデータを後から更新する(set_data)ために、ここでは先にオブジェクトのみを渡しています。
MatplotlibのGUI用ウィジェット
span = SpanSelector(
ax,
fit.select_callback,
"horizontal",
useblit=True,
props=dict(alpha=0.2, facecolor="tab:blue"),
interactive=True,
drag_from_anywhere=True
)
SpanSelectorはマウス操作(クリック&ドラッグ・リリース)でx軸を選択するためのウィジェットです。
引数の「fit.select_callback」はマウス操作に対応して実施される関数(Callback関数)を指定しています。
この関数は上部のクラス「Fit」内で定義されています。
また「interactive=True」で選択した領域を操作できるようにしています。
SpanSelectorの詳細は公式ページで確認できます。
フィッティングとCallback関数
class Fit:
def __init__(self, ax, df):
self.ax = ax or plt.gca()
self.df = df
def fitting(self, x1, x2):
xx = self.df.query(f'{x1}<x<{x2}').x
yy= self.df.query(f'{x1}<x<{x2}').y
result = fit_gaussian(xx, yy)
print("-------------Fitting result------------------")
pars = list(result.best_values.values())
print(pars)
fitted.set_data(self.df.x, gaussian(self.df.x, *pars) )
residual.set_data(self.df.x, self.df.y - gaussian(self.df.x, *pars) )
fig.canvas.draw()
fig.canvas.flush_events()
def select_callback(self, x1, x2):
self.fitting(x1, x2)
ここではフィッティングをマウス操作で実施できるようにしています。
init(self, ax, df)
フィッティングに必要なグラフ情報(ax)と元データ(df)はインスタンス作成の引数に設定
select_callback(self, x1, x2)
上記のSpanSelectorによって呼ばれる関数です。選択領域端のx1, x2を使用できます。
ここではx1,x2を引数にして直ぐにself.fittingを呼びます。
fitting(self, x1, x2)
padas dfのqueryで選択範囲のみのデータを切り出します。
result = fit_gaussian(xx, yy)でフィッティング実施し、resultから必要なデータを取り出します。
result.best_valuesはフィッティング変数(ここではgaussian関数で定義したamp, cen, wid)の結果が入っています。
values()で順番に数字を取り出して、list()でリスト化しています。
plotする際には「アスタリスク*」を変数名parsにつけてgasussian(x,amp,cen,wid)へ代入しています。
fitted.set_dat
は先に渡したプロットオブジェクトのデータを更新しています。
今後作りたいもの
今回は使用しませんでしたが、「fig.canvas.mpl_connect」を使うと、matplotlibの操作にキーボード操作などを付与できます。例えばキー入力してフィット関数を変えるなど自由度が増えますね。
2022 12/22追記:fig.canvas.mpl_connectを使う例を作成しました。
2023 1/7追記:より実践的なGUIを紹介しました。
まとめ
matplotlibのウィジェットとlmfitを用いてGUI操作で範囲を指定してフィッティングする方法を解説しました。
コメント