import sys
import numpy as np
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
def func(XY, a, b, x0, y0, c):
    """Evaluate an axis-aligned elliptic paraboloid over coordinate arrays.

    Model: z = a*(x - x0)**2 + b*(y - y0)**2 + c

    Args:
        XY: sequence whose first two entries are the x and y coordinate
            arrays (e.g. the pair returned by ``np.meshgrid``); any further
            entries are ignored.
        a, b: quadratic coefficients along x and y.
        x0, y0: location of the vertex.
        c: value of the surface at the vertex.

    Returns:
        The model values flattened to 1-D, the layout ``curve_fit`` expects
        for its model function output.
    """
    xs, ys = XY[:2]
    x_term = a * (xs - x0) ** 2
    y_term = b * (ys - y0) ** 2
    # Summation order (x_term + y_term) + c matches the original expression,
    # so floating-point results are bit-identical.
    return (x_term + y_term + c).ravel()
argvs = sys.argv
argc = len(argvs)
if argc != 2:
print('arg1 : csv file path\n')
input('Press any key to exit\n')
sys.exit()
data = np.loadtxt(argvs[1], delimiter=',')
y = np.linspace(0, data.shape[0]-1, data.shape[0])
x = np.linspace(0, data.shape[1]-1, data.shape[1])
X, Y = np.meshgrid(x, y)
parameter_initial = (1, 1, 50, 50, 0)
param, cov = curve_fit(func, (X, Y), data.ravel(), p0=parameter_initial)
print(param)
fitting_data = func((X, Y), param[0], param[1], param[2], param[3], param[4]).reshape(data.shape[0], data.shape[1])
fig = plt.figure()
ax = Axes3D(fig)
ax.plot_surface(X, Y, data)
ax.plot_surface(X, Y, fitting_data)
plt.show()