ちょいめも

物理/Python/Cの雑記帳

python 2次元フィッティング

#2次元フィッティング
#dataが入力されたcsvを入力してフィッティングする
#第一引数:csvファイルパス

import sys
import numpy as np
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def func(XY, a, b, x0, y0, c):
    x, y = XY[0:2]
    z = a*(x-x0)**2 + b*(y-y0)**2 + c
    return z.ravel() #1次元に変換する

argvs = sys.argv
argc = len(argvs)

if argc != 2:
    print('arg1 : csv file path\n')
    input('Press any key to exit\n')
    sys.exit()

data = np.loadtxt(argvs[1], delimiter=',')
y = np.linspace(0, data.shape[0]-1, data.shape[0])
x = np.linspace(0, data.shape[1]-1, data.shape[1])
X, Y = np.meshgrid(x, y)

#initial guesses for a,b,c:
parameter_initial = (1, 1, 50, 50, 0)
param, cov = curve_fit(func, (X, Y), data.ravel(), p0=parameter_initial)
print(param)

#result
fitting_data = func((X, Y), param[0], param[1], param[2], param[3], param[4]).reshape(data.shape[0], data.shape[1])

#plot graph
fig = plt.figure()
ax = Axes3D(fig)
#input data
ax.plot_surface(X, Y, data) #plot_surface #plot_wireframe #scatter
#fitting graph
ax.plot_surface(X, Y, fitting_data)
plt.show()