|
- # !/usr/bin/env python
- # -*- coding: utf-8 -*-
- """
- @Time : 2020/11/23 18:25
- @Author : Albert Darren
- @Contact : 2563491540@qq.com
- @File : orthogonal_polynomial_least_square_fitting.py
- @Version : Version 1.0.0
- @Description : TODO
- @Created By : PyCharm
- """
- import matplotlib.pyplot as plt
- import numpy as np
- import scipy.optimize as so
- import pylab as mpl
- import sympy as sp
-
# Independent variable shared by the symbolic basis functions; `calculate`
# and `draw` substitute numeric sample values for it via `.subs(x, ...)`.
x = sp.Symbol('x')
-
-
def runge(x):
    """Evaluate the Runge function f(x) = 1 / (1 + 25 x^2).

    Only arithmetic operators are used, so scalars and NumPy arrays
    (element-wise) are both accepted.
    """
    denominator = 1 + 25 * x ** 2
    return 1 / denominator
-
-
def func(p, x):
    """Evaluate the cubic model a0 + a1*x + a2*x**2 + a3*x**3.

    Parameters
    ----------
    p : sequence of exactly four coefficients (a0, a1, a2, a3);
        any other length raises ValueError on unpacking.
    x : scalar or NumPy array of abscissae.

    Returns
    -------
    The polynomial value(s) at ``x``.
    """
    c0, c1, c2, c3 = p
    # Accumulate term by term in the same left-to-right order as the
    # expanded expression, so floating-point results are unchanged.
    result = c0 + c1 * x
    result = result + c2 * x * x
    result = result + c3 * x * x * x
    return result
-
-
def err(p, x, y):
    """Return the residuals of the cubic model, ``func(p, x) - y``.

    Parameters
    ----------
    p : the four cubic coefficients forwarded to ``func``.
    x : sample abscissae.
    y : observed sample values.
    """
    model = func(p, x)
    return model - y
-
-
def _evaluate_basis(expr, point):
    """Return basis function ``expr`` evaluated at ``x = point``.

    SymPy expressions are substituted through their ``subs`` method; plain
    numbers (e.g. the constant basis function ``1``) have no ``subs`` and
    are returned unchanged.
    """
    substitute = getattr(expr, "subs", None)
    if substitute is None:
        return expr
    return substitute(x, point)


def calculate(expr_i, expr_j, expr_value, expr_omega):
    """Compute a weighted discrete inner product over the sample points.

    For every sample row ``(x_k, y_k)`` in ``expr_value`` this accumulates
    ``w_k * phi_i(x_k) * phi_j(x_k)`` and returns the sum.

    Parameters
    ----------
    expr_i : basis function (SymPy expression in ``x`` or a number), or the
        sentinel ``0``, meaning "use the sample's y value" -- this is how
        ``least_squares`` builds the right-hand side ``(f, phi_j)``.
    expr_j : basis function (SymPy expression in ``x`` or a number).
    expr_value : iterable of ``(x_k, y_k)`` rows, e.g. an ``(m, 2)`` array.
    expr_omega : a single weight shared by every point (any 0-d value:
        int, float, NumPy scalar), or a sequence of per-point weights
        indexed in the same order as ``expr_value``.

    Returns
    -------
    The accumulated sum: a SymPy number when a symbolic basis function was
    involved, otherwise a plain number (``0`` for no samples).
    """
    # Previously only Symbol/Pow bases were evaluated; any other expression
    # (e.g. ``2*x``, ``x + 1``, a constant other than 1) was silently replaced
    # by y (for expr_i) or by 1 (for expr_j). Evaluate every basis generically.
    use_y = expr_i == 0
    # Previously only an exact ``int`` counted as a shared weight, so float
    # or NumPy-scalar weights crashed when indexed.
    shared_weight = np.ndim(expr_omega) == 0

    total = 0
    for k, point in enumerate(expr_value):
        x_k = point[0]
        phi_i = point[1] if use_y else _evaluate_basis(expr_i, x_k)
        phi_j = _evaluate_basis(expr_j, x_k)
        weight = expr_omega if shared_weight else expr_omega[k]
        total = total + weight * phi_i * phi_j
    return total
-
-
def least_squares(_psi, _value, _omega):
    """Fit a linear combination of basis functions by weighted least squares.

    Builds the normal equations ``G a = d`` with ``G[i][j] = (phi_i, phi_j)``
    and ``d[i] = (f, phi_i)``, where each discrete inner product comes from
    ``calculate`` over the sample rows. It then solves for the coefficients
    ``a`` and returns ``sum(a_i * phi_i)``.

    Parameters
    ----------
    _psi : list of basis functions (SymPy expressions in ``x`` or numbers).
    _value : sample rows ``(x_k, y_k)``.
    _omega : weight(s) forwarded unchanged to ``calculate``.

    Returns
    -------
    The fitted function ``sum(a_i * psi_i)``; typically a SymPy expression
    when ``_psi`` contains symbolic terms.

    Raises
    ------
    numpy.linalg.LinAlgError
        If the Gram matrix ``G`` is singular.
    """
    n = len(_psi)
    gram = np.empty((n, n))
    rhs = np.empty(n)
    for row, basis_row in enumerate(_psi):
        # Right-hand side: the sentinel 0 makes `calculate` use the y values.
        rhs[row] = calculate(0, basis_row, _value, _omega)
        for col, basis_col in enumerate(_psi):
            gram[row, col] = calculate(basis_row, basis_col, _value, _omega)

    coefficients = np.linalg.solve(gram, rhs)
    return sum(coef * basis for coef, basis in zip(coefficients, _psi))
-
-
def draw(x_: np.ndarray, f: np.ndarray, fn):
    """Plot the fitted curve against the Runge function and the samples.

    Parameters
    ----------
    x_ : abscissae of the data points (drawn as a red scatter).
    f : ordinates of the data points.
    fn : expression in the module symbol ``x`` (the fitted curve); it is
        evaluated via ``subs`` on 100 evenly spaced points in [-1, 1].

    Side effects: draws on the current pyplot figure, sets the
    ``font.sans-serif`` / ``axes.unicode_minus`` rcParams (presumably so the
    Chinese labels render with SimHei), adds a legend and calls
    ``plt.show()``.
    """
    grid = np.linspace(-1, 1, 100)
    fitted = list(map(lambda t: fn.subs(x, t), grid))
    reference = list(map(runge, grid))

    # Fitted cubic S(x), then the continuous Runge function for comparison.
    plt.plot(grid, fitted, label='三次曲线拟合函数S(x)', color='green')
    plt.plot(grid, reference, label=r'f(x)=$\frac{1}{1+25x^{2}}$', color='yellow')
    # The sample points the curve was fitted to.
    plt.scatter(x_, f, label="数据点", color="red")
    plt.title("最小二乘法三次曲线拟合")

    settings = mpl.rcParams
    settings['font.sans-serif'] = ['SimHei']
    settings['axes.unicode_minus'] = False
    plt.legend(loc="upper right")
    plt.show()
-
-
- # 调用系统函数
- # if __name__ == '__main__':
- # psi = [pow(x, i) for i in range(4)]
- # value = np.array([[-1 + 0.2 * i, runge(-1 + 0.2 * i)] for i in range(11)])
- # arg_f = (np.array([x[0] for x in value[:, :1]]), np.array([y[0] for y in value[:, 1:2]]))
- # coef_ls = so.leastsq(err, [1, 1, 1, 1], args=arg_f)
- # print("拟合系数为:\n{}".format(coef_ls))
- # system_ls_f_x = 0
- # for i, element in enumerate(coef_ls[0]):
- # system_ls_f_x = system_ls_f_x + element * psi[i]
- # print("3次拟合曲线方程为:\ny={}".format(system_ls_f_x))
- # draw(value[:, :1], value[:, 1:2], system_ls_f_x)
-
- # 自编代码
- # if __name__ == '__main__':
- # psi = [pow(x, i) for i in range(4)]
- # value = np.array([[-1 + 0.2 * i, runge(-1 + 0.2 * i)] for i in range(11)])
- # omega = 1
- # # omega=[2,1,3,1,1]
- # ls_f_x = least_squares(psi, value, omega)
- # print("3次拟合曲线方程为:\ny={}".format(ls_f_x))
- # draw(value[:, :1], value[:, 1:2], ls_f_x)
- """
- # using system functions
- def func(p,x):
- a0,a1,a2,a3 = p
- return a0+a1*x+a2*x*x+a3*x*x*x
- def err(p,x,y):
- return func(p,x)-y
- arg_f=(np.array([x[0] for x in value[:,:1]]),np.array([y[0] for y in value[:,1:2]]))
- coef_ls = so.leastsq(err, [1,1,1,1], args=arg_f)
- print(coef_ls)
- system_ls_f_x=0
- for i,element in enumerate(coef_ls[0]):
- system_ls_f_x=system_ls_f_x+element*psi[i]
- print(system_ls_f_x)
- p1=sp.plot(f_x,ls_f_x,system_ls_f_x,(x,-1,1),show=False)
- p1[1].line_color='r'
- p1[2].line_color='g'
- p1.show()
- """
-
- # problem 2:
- # fig = plt.figure()
- #
- # value = np.array([[0, 1], [0.1, 0.41], [0.2, 0.50], [0.3, 0.61], [0.5, 0.91], [0.8, 2.02], [1.0, 2.46]])
- #
- # ls_f_x = least_squares(psi, value, omega)
- # print_f = sp.lambdify(x, ls_f_x, "numpy")
- # print_x = np.linspace(-1, 1, 100)
- # print_y = print_f(print_x)
- # plt.plot(print_x, print_y, label='x^3')
- #
- # psi = [1, x, x ** 2, x ** 3, x ** 4]
- # ls_f_x = least_squares(psi, value, omega)
- # print_f = sp.lambdify(x, ls_f_x, "numpy")
- # print_y = print_f(print_x)
- # plt.plot(print_x, print_y, label='x^4')
- #
- # psi = [1, x]
- # ls_f_x = least_squares(psi, value, omega)
- # print_f = sp.lambdify(x, ls_f_x, "numpy")
- # print_y = print_f(print_x)
- # plt.plot(print_x, print_y, label='kx+b')
- #
- # plt.scatter(np.array([x[0] for x in value[:, :1]]), np.array([y[0] for y in value[:, 1:2]]), marker='x', label='data')
- # plt.legend(loc='best')
- # plt.savefig('output.svg')
|