背景
项目中需要根据一些已有数据学习出一个 y=ax+b 的一元一次多项式(线性函数):给定了 x、y 的一些样本数据,通过梯度下降或最小二乘法做多项式拟合得到 a、b。解决该问题时,首先想到的是通过 spark mllib 去学习,可是结果并不理想:文档较少,参数也很难调整。于是转变了解决问题的方式:采用最小二乘法做多项式拟合。
最小二乘法多项式拟合描述下: (以下参考:https://blog.csdn.net/funnyrand/article/details/46742561)
假设给定的数据点和其对应的函数值为 (x1, y1), (x2, y2), ... (xm, ym),需要做的就是得到一个多项式函数 f(x) = a0 + a1 * x + a2 * pow(x, 2) + .. + an * pow(x, n),使其对所有给定 x 所计算出的 f(x) 与实际对应的 y 值的差的平方和最小,也就是计算多项式的各项系数 a0, a1, ... an。
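根据最小二乘法的原理,该问题可转换为求以下线性方程组的解:$Ga = B$,其中 $G_{ij} = \sum_{k=1}^{m} w_k x_k^{i+j}$,$B_i = \sum_{k=1}^{m} w_k x_k^{i} y_k$($i, j = 0, 1, \dots, n$;$w_k$ 为第 k 个样本的权重,不指定时取 1),$a = (a_0, a_1, \dots, a_n)^T$ 为待求的系数向量。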
所以从编程的角度来说需要做两件事情:
1)确定线性方程组的各个系数:
确定系数比较简单,对给定的 (x1, y1), (x2, y2), ... (xm, ym) 做相应的计算即可,相关代码:
private void compute() { ... }
2)解线性方程组:
解线性方程组稍微复杂,这里用到了高斯消元法,基本思想是通过递归做矩阵转换,逐渐减少求解的多项式系数的个数,相关代码:
private double[] calcLinearEquation(double[][] a, double[] b) { ... }
Java代码
/**
 * Weighted least-squares polynomial fit.
 *
 * <p>Given samples {@code (x[k], y[k])} with weights {@code weight[k]}, the class builds the
 * normal equations {@code G * a = B} with {@code G[i][j] = sum(w * x^(i+j))} and
 * {@code B[i] = sum(w * x^i * y)} and solves them for the coefficients
 * {@code a[0] + a[1]*x + ... + a[n-1]*x^(n-1)}.
 *
 * <p>When the coefficients cannot be determined (fewer samples than coefficients, or a singular
 * system), {@link #getCoefficient()} returns {@code null} and {@link #fit(double)} /
 * {@link #solve(double)} return {@code 0}.
 */
public class JavaLeastSquare {

    /** Newton's iteration stops once two successive estimates differ by no more than this. */
    private static final double NEWTON_EPS = 0.0000001d;

    /** Upper bound on Newton steps so that a cycling iteration cannot hang the caller. */
    private static final int NEWTON_MAX_ITERATIONS = 1000;

    private final double[] x;
    private final double[] y;
    private final double[] weight;
    private final int n;
    private double[] coefficient;

    /**
     * Fits a polynomial with every sample weighted 1.
     *
     * @param x Array of x (at least two values)
     * @param y Array of y (same length as {@code x})
     * @param n Number of polynomial coefficients to fit, i.e. degree + 1 (2 fits a straight line)
     * @throws IllegalArgumentException if an array is null, the lengths differ, fewer than two
     *                                  samples are given, or {@code n < 2}
     */
    public JavaLeastSquare(double[] x, double[] y, int n) {
        this(x, y, unitWeights(x), n);
    }

    /**
     * Fits a polynomial using the given per-sample weights.
     *
     * @param x      Array of x (at least two values)
     * @param y      Array of y (same length as {@code x})
     * @param weight Array of weight (same length as {@code x})
     * @param n      Number of polynomial coefficients to fit, i.e. degree + 1
     * @throws IllegalArgumentException if an array is null, the lengths differ, fewer than two
     *                                  samples are given, or {@code n < 2}
     */
    public JavaLeastSquare(double[] x, double[] y, double[] weight, int n) {
        if (x == null || y == null || weight == null || x.length < 2
                || x.length != y.length || x.length != weight.length || n < 2) {
            throw new IllegalArgumentException(
                    "IllegalArgumentException occurred.");
        }
        this.x = x;
        this.y = y;
        this.n = n;
        this.weight = weight;
        compute();
    }

    /**
     * Get coefficient of polynomial.
     *
     * @return a copy of the coefficients in ascending power order (index i multiplies x^i), or
     *         {@code null} when the fit could not be computed
     */
    public double[] getCoefficient() {
        return coefficient == null ? null : coefficient.clone();
    }

    /**
     * Used to calculate value by given x.
     *
     * @param x x
     * @return the fitted polynomial evaluated at {@code x}, or 0 when no fit is available
     */
    public double fit(double x) {
        if (coefficient == null) {
            return 0;
        }
        // Horner's scheme: one multiply-add per coefficient instead of a Math.pow call each.
        double sum = 0;
        for (int i = coefficient.length - 1; i >= 0; i--) {
            sum = sum * x + coefficient[i];
        }
        return sum;
    }

    /**
     * Use Newton's method to solve {@code fit(x) == y}, starting from x = 1.
     *
     * @param y y
     * @return x, or {@code NaN} when the iteration does not converge
     */
    public double solve(double y) {
        return solve(y, 1.0d);
    }

    /**
     * Use Newton's method to solve {@code fit(x) == y}.
     *
     * @param y      y
     * @param startX The start point of x
     * @return x; 0 when no fit is available; {@code NaN} when the iteration diverges or has not
     *         converged after {@value #NEWTON_MAX_ITERATIONS} steps
     */
    public double solve(double y, double startX) {
        if (coefficient == null) {
            return 0;
        }
        double current = startX;
        for (int step = 0; step < NEWTON_MAX_ITERATIONS; step++) {
            double next = current - (fit(current) - y) / derivative(current);
            // Written as a negated '>' so that a NaN estimate also terminates the loop.
            if (!(Math.abs(current - next) > NEWTON_EPS)) {
                return next;
            }
            current = next;
        }
        // Newton's method can cycle forever (e.g. x^3 - 2x + 2 from 0); report failure instead.
        return Double.NaN;
    }

    /*
     * Builds an all-ones weight array matching x; null in, null out so that the
     * delegating constructor still rejects a null x with IllegalArgumentException.
     */
    private static double[] unitWeights(double[] x) {
        if (x == null) {
            return null;
        }
        double[] ones = new double[x.length];
        for (int i = 0; i < ones.length; i++) {
            ones[i] = 1;
        }
        return ones;
    }

    /*
     * First derivative of the fitted polynomial at x (needed by Newton's method).
     */
    private double derivative(double x) {
        if (coefficient == null) {
            return 0;
        }
        double sum = 0;
        for (int i = coefficient.length - 1; i >= 1; i--) {
            sum = sum * x + i * coefficient[i];
        }
        return sum;
    }

    /*
     * Builds the normal equations from the samples and solves them.
     * Leaves coefficient null when there are fewer samples than coefficients.
     */
    private void compute() {
        if (x.length < n) {
            return;
        }
        double[] powerSums = new double[(n - 1) * 2 + 1]; // sum(w * x^i), i = 0 .. 2n-2
        double[] rhs = new double[n];                     // sum(w * x^i * y), i = 0 .. n-1
        for (int k = 0; k < x.length; k++) {
            double term = weight[k]; // w * x^i, advanced by one power per inner step
            for (int i = 0; i < powerSums.length; i++) {
                powerSums[i] += term;
                if (i < n) {
                    rhs[i] += term * y[k];
                }
                term *= x[k];
            }
        }
        double[][] gram = new double[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                gram[i][j] = powerSums[i + j];
            }
        }
        coefficient = calcLinearEquation(gram, rhs);
    }

    /*
     * Solves the linear system A * x = B by Gaussian elimination with partial pivoting.
     * @param a square two-dimensional array
     * @param b one-dimensional array
     * @return x, or null when the input is malformed or the system is singular
     */
    private double[] calcLinearEquation(double[][] a, double[] b) {
        if (a == null || b == null || a.length == 0 || a.length != b.length) {
            return null;
        }
        for (double[] row : a) {
            if (row == null || row.length != a.length) {
                return null;
            }
        }

        int size = a.length;
        double[][] m = new double[size][];
        for (int i = 0; i < size; i++) {
            m[i] = a[i].clone();
        }
        double[] rhs = b.clone();

        // Forward elimination; the largest remaining entry of each column is used as pivot.
        for (int col = 0; col < size; col++) {
            int pivotRow = col;
            for (int r = col + 1; r < size; r++) {
                if (Math.abs(m[r][col]) > Math.abs(m[pivotRow][col])) {
                    pivotRow = r;
                }
            }
            if (!(Math.abs(m[pivotRow][col]) > 0.0d)) {
                return null; // zero (or NaN) pivot: no unique solution
            }
            if (pivotRow != col) {
                double[] swapRow = m[pivotRow];
                m[pivotRow] = m[col];
                m[col] = swapRow;
                double swapValue = rhs[pivotRow];
                rhs[pivotRow] = rhs[col];
                rhs[col] = swapValue;
            }
            for (int r = col + 1; r < size; r++) {
                double factor = m[r][col] / m[col][col];
                if (factor == 0.0d) {
                    continue;
                }
                for (int c = col; c < size; c++) {
                    m[r][c] -= factor * m[col][c];
                }
                rhs[r] -= factor * rhs[col];
            }
        }

        // Back substitution.
        double[] result = new double[size];
        for (int row = size - 1; row >= 0; row--) {
            double sum = rhs[row];
            for (int c = row + 1; c < size; c++) {
                sum -= m[row][c] * result[c];
            }
            result[row] = sum / m[row][row];
            if (Double.isNaN(result[row]) || Double.isInfinite(result[row])) {
                return null;
            }
        }
        return result;
    }

    public static void main(String[] args) {
        JavaLeastSquare leastSquare = new JavaLeastSquare(
                new double[]{
                        2, 14, 20, 25, 26, 34,
                        47, 87, 165, 265, 365, 465,
                        565, 665
                },
                new double[]{
                        0.7 * 2 + 20 + 0.4,
                        0.7 * 14 + 20 + 0.5,
                        0.7 * 20 + 20 + 3.4,
                        0.7 * 25 + 20 + 5.8,
                        0.7 * 26 + 20 + 8.27,
                        0.7 * 34 + 20 + 0.4,

                        0.7 * 47 + 20 + 0.1,
                        0.7 * 87 + 20,
                        0.7 * 165 + 20,
                        0.7 * 265 + 20,
                        0.7 * 365 + 20,
                        0.7 * 465 + 20,

                        0.7 * 565 + 20,
                        0.7 * 665 + 20
                },
                2);

        double[] coefficients = leastSquare.getCoefficient();
        for (double c : coefficients) {
            System.out.println(c);
        }

        // Evaluate the fitted line at x = 4.
        System.out.println(leastSquare.fit(4));
    }
}
输出结果:
com.datangmobile.biz.leastsquare.JavaLeastSquare
22.27966881467629
0.6952475907448203
25.06065917765557
Process finished with exit code 0
使用开源库
也可使用Apache开源库commons math(http://commons.apache.org/proper/commons-math/userguide/fitting.html),提供的功能更强大:
<dependency>
    <groupId>org.apache.commons</groupId>
    <artifactId>commons-math3</artifactId>
    <version>3.5</version>
</dependency>
实现代码:
import org.apache.commons.math3.fitting.PolynomialCurveFitter;import org.apache.commons.math3.fitting.WeightedObservedPoints;public class WeightedObservedPointsTest { public static void main(String[] args) { final WeightedObservedPoints obs = new WeightedObservedPoints(); obs.add(2, 0.7 * 2 + 20 + 0.4); obs.add(12, 0.7 * 12 + 20 + 0.3); obs.add(32, 0.7 * 32 + 20 + 3.4); obs.add(34 , 0.7 * 34 + 20 + 5.8); obs.add(58 , 0.7 * 58 + 20 + 8.4); obs.add(43 , 0.7 * 43 + 20 + 0.28); obs.add(27 , 0.7 * 27 + 20 + 0.4); // Instantiate a two-degree polynomial fitter. final PolynomialCurveFitter fitter = PolynomialCurveFitter.create(2); // Retrieve fitted parameters (coefficients of the polynomial function). final double[] coeff = fitter.fit(obs.toList()); for (double c : coeff) { System.out.println(c); } }}
测试输出结果:
20.47425047847121
0.6749744063035112
0.002523043547711147
Process finished with exit code 0
使用org.ujmp(矩阵)实现最小二乘法:
pom.xml中需要引入org.ujmp
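<project>
    <modelVersion>4.0.0</modelVersion>
    <groupId>com.dtgroup</groupId>
    <artifactId>dtgroup</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <repositories>
        <repository>
            <id>limaven</id>
            <name>aliyun maven</name>
            <url>http://maven.aliyun.com/nexus/content/groups/public/</url>
            <layout>default</layout>
            <releases>
                <enabled>true</enabled>
            </releases>
            <snapshots>
                <enabled>false</enabled>
            </snapshots>
        </repository>
    </repositories>
    <dependencies>
        <dependency>
            <groupId>org.ujmp</groupId>
            <artifactId>ujmp-core</artifactId>
            <version>0.3.0</version>
        </dependency>
    </dependencies>
</project>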
java代码:
/** * 采用最小二乘法多项式拟合方式,获取多项式的系数。 * @param sampleCount 采样点个数 * @param fetureCount 多项式的系数 * @param samples 采样点集合 * **/ private static void leastsequare(int sampleCount, int fetureCout, Listsamples) { // 构件 2*2矩阵 存储X,元素值都为1.0000的矩阵 Matrix matrixX = DenseMatrix.Factory.ones(sampleCount, fetureCout); for (int i = 0; i < samples.size(); i++) { matrixX.setAsDouble(samples.get(i).getX(), i, 1); } // System.out.println(matrixX); System.out.println("--------------------------------------"); // 构件 2*2矩阵 存储X Matrix matrixY = DenseMatrix.Factory.ones(sampleCount, 1); for (int i = 0; i < samples.size(); i++) { matrixY.setAsDouble(samples.get(i).getY(), i, 0); } // System.out.println(matrixY); // 对X进行转置 Matrix matrixXTrans = matrixX.transpose(); // System.out.println(matrixXTrans); // 乘积运算:x*转转置后x:matrixXTrans*matrixX Matrix matrixMtimes = matrixXTrans.mtimes(matrixX); System.out.println(matrixMtimes); System.out.println("--------------------------------------"); // 求逆 Matrix matrixMtimesInv = matrixMtimes.inv(); System.out.println(matrixMtimesInv); // x转置后结果*求逆结果 System.out.println("--------------------------------------"); Matrix matrixMtimesInvMtimes = matrixMtimesInv.mtimes(matrixXTrans); System.out.println(matrixMtimesInvMtimes); System.out.println("--------------------------------------"); Matrix theta = matrixMtimesInvMtimes.mtimes(matrixY); System.out.println(theta); }
测试代码:
public static void main(String[] args) { /** * y=ax+b * * a(0,1] b[5,20] * * x[0,500] y>=5 */ // y= 0.8d*x+15 // 当x不变动时,y对应有多个值;此时把y求均值。 Listsamples = new ArrayList (); samples.add(new Sample(0.8d * 1 + 15 + 1, 1d)); samples.add(new Sample(0.8d * 4 + 15 + 0.8, 4d)); samples.add(new Sample(0.8d * 3 + 15 + 0.7, 3d)); samples.add(new Sample(0.8d * 24 + 15 + 0.4, 24d)); samples.add(new Sample(0.8d * 5 + 15 + 0.3, 5d)); samples.add(new Sample(0.8d * 10 + 15 + 0.4, 10d)); samples.add(new Sample(0.8d * 14 + 15 + 0.2, 14d)); samples.add(new Sample(0.8d * 7 + 15 + 0.3, 7d)); samples.add(new Sample(0.8d * 1000 + 23 + 0.3, 70d)); int sampleCount = samples.size(); int fetureCout = 2; leastsequare(sampleCount, fetureCout, samples); }
过滤样本中的噪点:
public static void main(String[] args) { /** * y=ax+b * * a(0,1] b[5,20] * * x[0,500] y>=5 */ // y= 0.8d*x+15 // 当x不变动时,y对应有多个值;此时把y求均值。 Listsamples = new ArrayList (); samples.add(new Sample(0.8d * 1 + 15 + 1, 1d)); samples.add(new Sample(0.8d * 4 + 15 + 0.8, 4d)); samples.add(new Sample(0.8d * 3 + 15 + 0.7, 3d)); samples.add(new Sample(0.8d * 24 + 15 + 0.4, 24d)); samples.add(new Sample(0.8d * 5 + 15 + 0.3, 5d)); samples.add(new Sample(0.8d * 10 + 15 + 0.4, 10d)); samples.add(new Sample(0.8d * 14 + 15 + 0.2, 14d)); samples.add(new Sample(0.8d * 7 + 15 + 0.3, 7d)); samples.add(new Sample(0.8d * 1000 + 23 + 0.3, 70d)); // samples = filterSample(samples); sortSample(samples); FilterSampleByGradientResult result = filterSampleByGradient(0, samples); while (result.isComplete() == false) { List newSamples=result.getSamples(); sortSample(newSamples); result = filterSampleByGradient(result.getIndex(), newSamples); } samples = result.getSamples(); for (Sample sample : samples) { System.out.println(sample); } int sampleCount = samples.size(); int fetureCout = 2; leastsequare(sampleCount, fetureCout, samples); } /** * 对采样点进行排序,按照x排序,升序排列 * @param samples 采样点集合 * **/ private static void sortSample(List samples) { samples.sort(new Comparator () { public int compare(Sample o1, Sample o2) { if (o1.getX() > o2.getX()) { return 1; } else if (o1.getX() <= o2.getX()) { return -1; } return 0; } }); } /** * 过滤采样点中的噪点(采样过滤方式:double theta=(y2-y1)/(x2-x1),theta就是一个斜率,根据该值范围来过滤。) * @param index 记录上次过滤索引 * @param samples 采样点集合(将从其中过滤掉噪点) * **/ private static FilterSampleByGradientResult filterSampleByGradient(int index, List samples) { int sampleSize = samples.size(); for (int i = index; i < sampleSize - 1; i++) { double delta_x = samples.get(i).getX() - samples.get(i + 1).getX(); double delta_y = samples.get(i).getY() - samples.get(i + 1).getY(); // 距离小于2米 if (Math.abs(delta_x) < 1) { double newY = (samples.get(i).getY() + samples.get(i + 1).getY()) / 2; double newX = 
samples.get(i).getX(); samples.remove(i); samples.remove(i + 1); samples.add(new Sample(newY, newX)); return new FilterSampleByGradientResult(false, i, samples); } else { double gradient = delta_y / delta_x; if (gradient > 1.5) { if (i == 0) { // double newY = (samples.get(i).getY() + samples.get(i // + 1).getY()) / 2; // double newX = (samples.get(i).getX() + samples.get(i // + 1).getX()) / 2; // samples.remove(i); // samples.add(new Sample(newY, newX)); } else { samples.remove(i + 1); } return new FilterSampleByGradientResult(false, i, samples); } } } return new FilterSampleByGradientResult(true, 0, samples); }
使用距离来处理过滤:
private static ListfilterSample(List samples) { // x={x1,x2,x3...xn} // u=E(x) ---x的期望(均值)为 u // 6=sqrt(pow((x1-u),2)+pow((x2-u),2)+pow((x3-u),2)+...+pow((xn-u),2)) // 6为x的标准差,标准差=sqrt(方差) // 剔除噪点可以采用: // 若xi不属于(u-3*6,u+3*6),则认为它是噪点。 // 另外一种方案,对x/y都做上边的处理,之后如果两个结果为and 或者 or操作来选取是否剔除。 // 用点的方式来过滤数据,求出一个中值点,求其他点到该点的距离。 int sampleCount = samples.size(); double sumX = 0d; double sumY = 0d; for (Sample sample : samples) { sumX += sample.getX(); sumY += sample.getY(); } // 求中心点 double centerX = (sumX / sampleCount); double centerY = (sumY / sampleCount); List distanItems = new ArrayList (); // 计算出所有点距离该中心点的距离 for (int i = 0; i < samples.size(); i++) { Sample sample = samples.get(i); Double xyPow2 = Math.pow(sample.getX() - centerX, 2) + Math.pow(sample.getY() - centerY, 2); distanItems.add(Math.sqrt(xyPow2)); } // 以下对根据距离(所有点距离中心点的距离)进行筛选 double sumDistan = 0d; double distanceU = 0d; for (Double distance : distanItems) { sumDistan += distance; } distanceU = sumDistan / sampleCount; double deltaPowSum = 0d; double distanceTheta = 0d; // sqrt(pow((x1-u),2)+pow((x2-u),2)+pow((x3-u),2)+...+pow((xn-u),2)) for (Double distance : distanItems) { deltaPowSum += Math.pow((distance - distanceU), 2); } distanceTheta = Math.sqrt(deltaPowSum); // 剔除噪点可以采用: // 若xi不属于(u-3*6,u+3*6),则认为它是噪点。 double minDistance = distanceU - 0.5 * distanceTheta; double maxDistance = distanceU + 0.5 * distanceTheta; List willbeRemoveIdxs = new ArrayList (); for (int i = distanItems.size() - 1; i >= 0; i--) { Double distance = distanItems.get(i); if (distance <= minDistance || distance >= maxDistance) { willbeRemoveIdxs.add(i); System.out.println("will be remove " + i); } } for (int willbeRemoveIdx : willbeRemoveIdxs) { samples.remove(willbeRemoveIdx); } return samples; }
实际业务测试:
package com.zjanalyse.spark.maths;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

import org.ujmp.core.DenseMatrix;
import org.ujmp.core.Matrix;

/**
 * Two-pass least-squares line fit with outlier rejection.
 *
 * <p>Pipeline: sort the samples by x, drop samples outside the business value ranges, fit
 * y = path_loss * x + wear_loss, drop samples whose distance to that line is atypical, then
 * fit again on the remaining samples.
 */
public class LastSquare {
    /**
     * Demo driver. Intended ranges for y = ax + b: a in (0,1], b in [5,20], x in [0,500],
     * y >= 5. When the same x has several y values, the y values are meant to be averaged.
     */
    public static void main(String[] args) {
        // Base line is y = 0.8x + 15; Sample's constructor takes (y, x).
        List<Sample> samples = new ArrayList<>();
        samples.add(new Sample(0.8d * 11 + 15 + 1, 11d));
        samples.add(new Sample(0.8d * 24 + 15 + 0.8, 24d));
        samples.add(new Sample(0.8d * 33 + 15 + 0.7, 33d));
        samples.add(new Sample(0.8d * 24 + 15 + 0.4, 24d));
        samples.add(new Sample(0.8d * 47 + 15 + 0.3, 47d));
        samples.add(new Sample(0.8d * 60 + 15 + 0.4, 60d));
        samples.add(new Sample(0.8d * 14 + 15 + 0.2, 14d));
        samples.add(new Sample(0.8d * 57 + 15 + 0.3, 57d));
        // The last three points use a different intercept, i.e. they are off the base line.
        samples.add(new Sample(0.8d * 70 + 60 + 0.3, 70d));
        samples.add(new Sample(0.8d * 80 + 60 + 0.3, 80d));
        samples.add(new Sample(0.8d * 40 + 30 + 0.3, 40d));

        sortSample(samples);
        System.out.println("原始样本数据");
        for (Sample sample : samples) {
            System.out.println(sample);
        }

        System.out.println("开始“所有点”通过“业务数据取值范围”剔除:");
        filterByBusiness(samples);
        System.out.println("结束“所有点”通过“业务数据取值范围”剔除:");
        for (Sample sample : samples) {
            System.out.println(sample);
        }

        int sampleCount = samples.size();
        int fetureCout = 2;

        System.out.println("第一次拟合。。。");
        // Same guard as for the second fit: a line cannot be fitted to fewer than two samples.
        if (sampleCount < 2) {
            System.out.println("complete...");
            return;
        }
        Matrix theta = leastsequare(sampleCount, fetureCout, samples);

        // theta(0,0) is the intercept, theta(1,0) the slope.
        double wear_loss = theta.getAsDouble(0, 0);
        double path_loss = theta.getAsDouble(1, 0);
        System.out.println("wear loss " + wear_loss);
        System.out.println("path loss " + path_loss);

        System.out.println("开始“所有点”与“第一多项式拟合结果直线方式距离方差”剔除:");
        samples = filterSample(wear_loss, path_loss, samples);
        System.out.println("结束“所有点”与“第一多项式拟合结果直线方式距离方差”剔除:");
        for (Sample sample : samples) {
            System.out.println(sample);
        }

        System.out.println("第二次拟合。。。");
        sampleCount = samples.size();
        fetureCout = 2;

        if (sampleCount >= 2) {
            theta = leastsequare(sampleCount, fetureCout, samples);

            wear_loss = theta.getAsDouble(0, 0);
            path_loss = theta.getAsDouble(1, 0);
            System.out.println("wear loss " + wear_loss);
            System.out.println("path loss " + path_loss);
        }
        System.out.println("complete...");
    }

    /**
     * Removes, in place, samples outside the business value ranges: x must be below 500 and
     * y must lie within [0*x+5, 1*x+30]. A message is printed for every removed sample.
     */
    private static void filterByBusiness(List<Sample> samples) {
        for (int i = 0; i < samples.size(); i++) {
            double x = samples.get(i).getX();
            double y = samples.get(i).getY();
            if (x >= 500) {
                System.out.println(x + " x值超出有效值范围[0,500)");
                samples.remove(i);
                i--; // the next sample moved into slot i
            } else if (y < 0 * x + 5 || y > 1 * x + 30) {
                System.out.println(
                        y + " y值超出有效值范围[(0*x+5),(1*x+30)]其中x=" + x + ",也就是:[" + (0 * x + 5) + "," + (1 * x + 30) + ")");
                samples.remove(i);
                i--; // the next sample moved into slot i
            }
        }
    }

    /**
     * Distance from a point to the line {@code A*x + B*y + C = 0}.
     *
     * <p>Example: the distance from (0, 0) to y = x is
     * {@code getDistanceOfPerpendicular(0, 0, -1, 1, 0)}, which is 0.
     *
     * @param x1 x coordinate of the point
     * @param y1 y coordinate of the point
     * @param A  coefficient A of the general line equation
     * @param B  coefficient B of the general line equation
     * @param C  coefficient C of the general line equation
     * @return the perpendicular distance from the point to the line
     */
    private static double getDistanceOfPerpendicular(double x1, double y1, double A, double B, double C) {
        return Math.abs((A * x1 + B * y1 + C) / Math.sqrt(A * A + B * B));
    }

    /**
     * Removes samples whose (square-rooted) distance to the fitted line
     * {@code y = path_loss * x + wear_loss} falls outside a band around the mean.
     *
     * <p>With d the square root of each sample's perpendicular distance, u the mean of d and
     * s = sqrt(sum((d - u)^2)), a sample is dropped when {@code d <= u - 0.25 * s} or
     * {@code d >= u + 0.25 * s}. Every value and every removal is printed.
     *
     * <p>NOTE(review): the values compared are sqrt(distance), and s is not divided by the
     * sample count, so it is not the standard deviation the original comments describe; the
     * 0.25 factor appears tuned to these definitions — confirm before changing any of them.
     *
     * @param wear_loss intercept of the fitted line
     * @param path_loss slope of the fitted line
     * @param samples   the sample points; modified in place
     * @return the same list instance, with the rejected samples removed
     */
    private static List<Sample> filterSample(double wear_loss, double path_loss, List<Sample> samples) {
        int sampleCount = samples.size();

        // y = path_loss * x + wear_loss  <=>  path_loss * x - y + wear_loss = 0
        List<Double> distanItems = new ArrayList<>();
        for (Sample sample : samples) {
            double distance = getDistanceOfPerpendicular(sample.getX(), sample.getY(), path_loss, -1, wear_loss);
            distanItems.add(Math.sqrt(distance));
        }

        // Mean of the values.
        double sumDistan = 0d;
        for (double distance : distanItems) {
            sumDistan += distance;
        }
        double distanceU = sumDistan / sampleCount;

        // Spread of the values around their mean (see the review note above).
        double deltaPowSum = 0d;
        for (double distance : distanItems) {
            deltaPowSum += Math.pow((distance - distanceU), 2);
        }
        double distanceTheta = Math.sqrt(deltaPowSum);

        double minDistance = distanceU - 0.25 * distanceTheta;
        double maxDistance = distanceU + 0.25 * distanceTheta;

        // Collected from the back so that removal by index stays valid.
        List<Integer> willbeRemoveIdxs = new ArrayList<>();
        for (int i = distanItems.size() - 1; i >= 0; i--) {
            double distance = distanItems.get(i);
            if (distance <= minDistance || distance >= maxDistance) {
                System.out.println(distance + " out of range [" + minDistance + "," + maxDistance + "]");
                willbeRemoveIdxs.add(i);
            } else {
                System.out.println(distance);
            }
        }

        // 'int' (not Integer) so that List.remove(int index) is called, not remove(Object).
        for (int willbeRemoveIdx : willbeRemoveIdxs) {
            Sample sample = samples.get(willbeRemoveIdx);
            System.out.println("remove " + sample);
            samples.remove(willbeRemoveIdx);
        }

        return samples;
    }

    /**
     * Sorts the samples in place by x, ascending.
     *
     * @param samples the sample points
     **/
    private static void sortSample(List<Sample> samples) {
        // Double.compare returns 0 for equal x, as the Comparator contract requires.
        samples.sort(new Comparator<Sample>() {
            @Override
            public int compare(Sample o1, Sample o2) {
                return Double.compare(o1.getX(), o2.getX());
            }
        });
    }

    /**
     * Fits a polynomial to the samples with the normal equation
     * {@code theta = (X^T X)^-1 X^T y}.
     *
     * <p>Row i of the design matrix X is {@code [1, x_i, x_i^2, ...]} with {@code fetureCout}
     * columns, so {@code fetureCout = 2} fits the straight line {@code y = theta0 + theta1 * x}.
     *
     * @param sampleCount number of samples; must equal {@code samples.size()}
     * @param fetureCout  number of polynomial coefficients to fit (at least 1)
     * @param samples     the sample points
     * @return theta as a {@code fetureCout x 1} matrix, lowest power first
     * @throws IllegalArgumentException if {@code samples} is null, {@code sampleCount} does
     *                                  not match its size, or {@code fetureCout < 1}
     **/
    private static Matrix leastsequare(int sampleCount, int fetureCout, List<Sample> samples) {
        if (samples == null || sampleCount != samples.size() || fetureCout < 1) {
            throw new IllegalArgumentException(
                    "sampleCount must equal samples.size() and fetureCout must be >= 1");
        }

        // Design matrix X (sampleCount x fetureCout): starts as all ones, which is already the
        // correct constant column; column j is then overwritten with x^j.
        Matrix matrixX = DenseMatrix.Factory.ones(sampleCount, fetureCout);
        for (int i = 0; i < samples.size(); i++) {
            double x = samples.get(i).getX();
            for (int j = 1; j < fetureCout; j++) {
                matrixX.setAsDouble(Math.pow(x, j), i, j);
            }
        }

        // Column vector y (sampleCount x 1).
        Matrix matrixY = DenseMatrix.Factory.ones(sampleCount, 1);
        for (int i = 0; i < samples.size(); i++) {
            matrixY.setAsDouble(samples.get(i).getY(), i, 0);
        }

        Matrix matrixXTrans = matrixX.transpose();                           // X^T
        Matrix matrixMtimes = matrixXTrans.mtimes(matrixX);                  // X^T * X
        Matrix matrixMtimesInv = matrixMtimes.inv();                         // (X^T * X)^-1
        Matrix matrixMtimesInvMtimes = matrixMtimesInv.mtimes(matrixXTrans); // (X^T * X)^-1 * X^T

        // theta = (X^T * X)^-1 * X^T * y
        return matrixMtimesInvMtimes.mtimes(matrixY);
    }
}