当サイトには広告・プロモーションが含まれています。

Matplotlib | GUIで範囲選択してフィッティング

この記事で分かること
  • Matplotlibを用いてGUI上でフィッティングを実施。
  • 選択範囲を選ぶと結果をリアルタイムにプロットできます。
  • 意外とweb上に見当たらなかったので作りました。
目次

コード例


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.widgets import SpanSelector
from lmfit import Model

def gaussian(x, amp, cen, wid):
    return (amp / (np.sqrt(2*np.pi) * wid)) * np.exp(-(x-cen)**2 / (2*wid**2))

def fit_gaussian(x, data):
    mod = Model(gaussian) 
    pars = mod.make_params(amp=1, cen=1, wid=1)
    result = mod.fit(data, pars, x=x)
    return result

class Fit:
    def __init__(self, ax, df):
        self.ax = ax or plt.gca()
        self.df = df

    def fitting(self, x1, x2):
        xx = self.df.query(f'{x1}<x<{x2}').x
        yy= self.df.query(f'{x1}<x<{x2}').y

        result = fit_gaussian(xx, yy)
        print("-------------Fitting result------------------")
        pars =  list(result.best_values.values())
        print(pars)
        
        fitted.set_data(self.df.x, gaussian(self.df.x, *pars) )
        residual.set_data(self.df.x, self.df.y - gaussian(self.df.x, *pars) )

        fig.canvas.draw()
        fig.canvas.flush_events()   

    def select_callback(self, x1, x2):
        self.fitting(x1, x2)

# デモデータ
data = np.random.randn(10000)
histo,bins = np.histogram(data,range=(-5,5),bins=100,density=True)
x=bins[1:]
df = pd.DataFrame(data=np.stack([x,histo]).T, columns=['x', 'y'])

# プロット
fig = plt.figure(constrained_layout=True)
ax = fig.add_subplot(111)
ax.plot(df.x, df.y, ".", label="data")
ax.set_xlabel("x")
ax.set_ylabel("y (a.u.)")
# あとで更新する
fitted, = ax.plot([],[], c="r", label = "fitted")
residual, = ax.plot([], [], drawstyle = "steps-post", label = "residual")
ax.legend()

fit = Fit(ax, df)

# x軸を選ぶ
span = SpanSelector(
    ax,
    fit.select_callback,
    "horizontal",
    useblit=True,
    props=dict(alpha=0.2, facecolor="tab:blue"),
    interactive=True,
    drag_from_anywhere=True
)

plt.show()

ポイント解説

ライブラリ

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.widgets import SpanSelector
from lmfit import Model

def gaussian(x, amp, cen, wid):
    return (amp / (np.sqrt(2*np.pi) * wid)) * np.exp(-(x-cen)**2 / (2*wid**2))

def fit_gaussian(x, data):
    mod = Model(gaussian) 
    pars = mod.make_params(amp=1, cen=1, wid=1)
    result = mod.fit(data, pars, x=x)
    return result

from matplotlib.widgets import SpanSelector」がmatplotlibのGUI操作に必要なモジュールです。

from lmfit import Model」がフィッティングに使用するモジュールです。
直後にガウシアン関数とその関数をもとにフィッティングする関数を定義しています。

更新用のplt

# あとで更新する
fitted, = ax.plot([], [], c="r", label = "fitted")

フィッティングによりデータを後から更新する(set_data)ために、ここでは先にオブジェクトのみを渡しています。

MatplotlibのGUI用ウィジェット

span = SpanSelector(
    ax,
    fit.select_callback,
    "horizontal",
    useblit=True,
    props=dict(alpha=0.2, facecolor="tab:blue"),
    interactive=True,
    drag_from_anywhere=True
)

SpanSelectorはマウス操作(クリック&ドラッグ・リリース)でx軸を選択するためのウィジェットです。

引数の「fit.select_callback」はマウス操作に対応して実施される関数(Callback関数)を指定しています。
この関数は上部のクラス「Fit」内で定義されています。

また「interactive=True」で選択した領域を操作できるようにしています。

y軸も含めて長方形で選択したい場合は、別ウィジェット(RectangleSelector)を使いましょう。

SpanSelectorの詳細は公式ページで確認できます。

フィッティングとCallback関数

class Fit:
    def __init__(self, ax, df):
        self.ax = ax or plt.gca()
        self.df = df

    def fitting(self, x1, x2):
        xx = self.df.query(f'{x1}<x<{x2}').x
        yy= self.df.query(f'{x1}<x<{x2}').y

        result = fit_gaussian(xx, yy)
        print("-------------Fitting result------------------")
        pars =  list(result.best_values.values())
        print(pars)
        
        fitted.set_data(self.df.x, gaussian(self.df.x, *pars) )
        residual.set_data(self.df.x, self.df.y - gaussian(self.df.x, *pars) )

        fig.canvas.draw()
        fig.canvas.flush_events()   

    def select_callback(self, x1, x2):
        self.fitting(x1, x2)

ここではフィッティングをマウス操作で実施できるようにしています。

init(self, ax, df)
フィッティングに必要なグラフ情報(ax)と元データ(df)はインスタンス作成の引数に設定

select_callback(self, x1, x2)
上記のSpanSelectorによって呼ばれる関数です。選択領域端のx1, x2を使用できます。
ここではx1,x2を引数にして直ぐにself.fittingを呼びます。

fitting(self, x1, x2)
padas dfのqueryで選択範囲のみのデータを切り出します。
result = fit_gaussian(xx, yy)でフィッティング実施し、resultから必要なデータを取り出します。

result.best_valuesはフィッティング変数(ここではgaussian関数で定義したamp, cen, wid)の結果が入っています。
values()で順番に数字を取り出して、list()でリスト化しています。
plotする際には「アスタリスク*」を変数名parsにつけてgasussian(x,amp,cen,wid)へ代入しています。

lmfitでは結果の取り出しに「result.best_fit」や 「result.best_values[‘変数名’]」などの方法もありますが、今回はフィット範囲以外までプロットしたいのでパラメータのみを利用しました。

フィット範囲内のみプロットできればOKなら、
ax.plot(xx, result.best_fit)

ax.plot(xx, fit_line(xx, result.best_values[‘amp’], result.best_values[‘cen’], result.best_values[‘wid’]))
と書くこともできます。

fitted.set_dat
は先に渡したプロットオブジェクトのデータを更新しています。

今後作りたいもの

今回は使用しませんでしたが、「fig.canvas.mpl_connect」を使うと、matplotlibの操作にキーボード操作などを付与できます。例えばキー入力してフィット関数を変えるなど自由度が増えますね。

2022 12/22追記:fig.canvas.mpl_connectを使う例を作成しました。

2023 1/7追記:より実践的なGUIを紹介しました。

まとめ

matplotlibのウィジェットとlmfitを用いてGUI操作で範囲を指定してフィッティングする方法を解説しました。

よかったらシェアしてね!
  • URLをコピーしました!
  • URLをコピーしました!

コメント

コメントする

CAPTCHA


目次