博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
最小二乘法多项式拟合的Java实现
阅读量:6831 次
发布时间:2019-06-26

本文共 27112 字,大约阅读时间需要 90 分钟。

背景

由项目中需要根据一些已有数据学习出一个y=ax+b的一元二项式,给定了x,y的一些样本数据,通过梯度下降或最小二乘法做多项式拟合得到a、b,解决该问题时,首先想到的是通过spark mllib去学习,可是结果并不理想:少量的文档,参数也很难调整。于是转变了解决问题的方式:采用了最小二乘法做多项式拟合。

最小二乘法多项式拟合描述下: (以下参考:https://blog.csdn.net/funnyrand/article/details/46742561)

假设给定的数据点和其对应的函数值为 (x1, y1), (x2, y2), ... (xm, ym),需要做的就是得到一个多项式函数f(x) = a0 * x + a1 * pow(x, 2) + .. + an * pow(x, n),使其对所有给定x所计算出的f(x)与实际对应的y值的差的平方和最小,也就是计算多项式的各项系数 a0, a1, ... an. 

根据最小二乘法的原理,该问题可转换为求以下线性方程组的解:Ga = B

所以从编程的角度来说需要做两件事情:

1)确定线性方程组的各个系数:

确定系数比较简单,对给定的 (x1, y1), (x2, y2), ... (xm, ym) 做相应的计算即可,相关代码:

private void compute() {
  ... }

2)解线性方程组:

解线性方程组稍微复杂,这里用到了高斯消元法,基本思想是通过递归做矩阵转换,逐渐减少求解的多项式系数的个数,相关代码:

private double[] calcLinearEquation(double[][] a, double[] b) {
  ... }

Java代码

1 public class JavaLeastSquare {  2     private double[] x;  3     private double[] y;  4     private double[] weight;  5     private int n;  6     private double[] coefficient;  7   8     /**  9      * Constructor method. 10      * @param x Array of x 11      * @param y Array of y 12      * @param n The order of polynomial 13      */ 14     public JavaLeastSquare(double[] x, double[] y, int n) { 15         if (x == null || y == null || x.length < 2 || x.length != y.length 16                 || n < 2) { 17             throw new IllegalArgumentException( 18                     "IllegalArgumentException occurred."); 19         } 20         this.x = x; 21         this.y = y; 22         this.n = n; 23         weight = new double[x.length]; 24         for (int i = 0; i < x.length; i++) { 25             weight[i] = 1; 26         } 27         compute(); 28     } 29  30     /** 31      * Constructor method. 32      * @param x      Array of x 33      * @param y      Array of y 34      * @param weight Array of weight 35      * @param n      The order of polynomial 36      */ 37     public JavaLeastSquare(double[] x, double[] y, double[] weight, int n) { 38         if (x == null || y == null || weight == null || x.length < 2 39                 || x.length != y.length || x.length != weight.length || n < 2) { 40             throw new IllegalArgumentException( 41                     "IllegalArgumentException occurred."); 42         } 43         this.x = x; 44         this.y = y; 45         this.n = n; 46         this.weight = weight; 47         compute(); 48     } 49  50     /** 51      * Get coefficient of polynomial. 52      * @return coefficient of polynomial 53      */ 54     public double[] getCoefficient() { 55         return coefficient; 56     } 57  58     /** 59      * Used to calculate value by given x. 60      * @param x x 61      * @return y 62      */ 63     public double fit(double x) { 64         if (coefficient == null) { 65             return 0; 66         } 67         double sum = 0; 68         for (int i = 0; i < coefficient.length; i++) { 69             sum += Math.pow(x, i) * coefficient[i]; 70         } 71         return sum; 72     } 73  74     /** 75      * Use Newton's method to solve equation. 76      * @param y y 77      * @return x 78      */ 79     public double solve(double y) { 80         return solve(y, 1.0d); 81     } 82  83     /** 84      * Use Newton's method to solve equation. 85      * @param y      y 86      * @param startX The start point of x 87      * @return x 88      */ 89     public double solve(double y, double startX) { 90         final double EPS = 0.0000001d; 91         if (coefficient == null) { 92             return 0; 93         } 94         double x1 = 0.0d; 95         double x2 = startX; 96         do { 97             x1 = x2; 98             x2 = x1 - (fit(x1) - y) / calcReciprocal(x1); 99         } while (Math.abs((x1 - x2)) > EPS);100         return x2;101     }102 103     /*104      * Calculate the reciprocal of x.105      * @param x x106      * @return the reciprocal of x107      */108     private double calcReciprocal(double x) {109         if (coefficient == null) {110             return 0;111         }112         double sum = 0;113         for (int i = 1; i < coefficient.length; i++) {114             sum += i * Math.pow(x, i - 1) * coefficient[i];115         }116         return sum;117     }118 119     /*120      * This method is used to calculate each elements of augmented matrix.121      */122     private void compute() {123         if (x == null || y == null || x.length <= 1 || x.length != y.length124                 || x.length < n || n < 2) {125             return;126         }127         double[] s = new double[(n - 1) * 2 + 1];128         for (int i = 0; i < s.length; i++) {129             for (int j = 0; j < x.length; j++) {130                 s[i] += Math.pow(x[j], i) * weight[j];131             }132         }133         double[] b = new double[n];134         for (int i = 0; i < b.length; i++) {135             for (int j = 0; j < x.length; j++) {136                 b[i] += Math.pow(x[j], i) * y[j] * weight[j];137             }138         }139         double[][] a = new double[n][n];140         for (int i = 0; i < n; i++) {141             for (int j = 0; j < n; j++) {142                 a[i][j] = s[i + j];143             }144         }145 146         // Now we need to calculate each coefficients of augmented matrix147         coefficient = calcLinearEquation(a, b);148     }149 150     /*151      * Calculate linear equation.152      * The matrix equation is like this: Ax=B153      * @param a two-dimensional array154      * @param b one-dimensional array155      * @return x, one-dimensional array156      */157     private double[] calcLinearEquation(double[][] a, double[] b) {158         if (a == null || b == null || a.length == 0 || a.length != b.length) {159             return null;160         }161 162         for (double[] x : a) {163             if (x == null || x.length != a.length)164                 return null;165         }166 167         int len = a.length - 1;168         double[] result = new double[a.length];169 170         if (len == 0) {171             result[0] = b[0] / a[0][0];172             return result;173         }174 175         double[][] aa = new double[len][len];176         double[] bb = new double[len];177         int posx = -1, posy = -1;178         for (int i = 0; i <= len; i++) {179             for (int j = 0; j <= len; j++)180                 if (a[i][j] != 0.0d) {181                     posy = j;182                     break;183                 }184             if (posy != -1) {185                 posx = i;186                 break;187             }188         }189         if (posx == -1) {190             return null;191         }192 193         int count = 0;194         for (int i = 0; i <= len; i++) {195             if (i == posx) {196                 continue;197             }198             bb[count] = b[i] * a[posx][posy] - b[posx] * a[i][posy];199             int count2 = 0;200             for (int j = 0; j <= len; j++) {201                 if (j == posy) {202                     continue;203                 }204                 aa[count][count2] = a[i][j] * a[posx][posy] - a[posx][j] * a[i][posy];205                 count2++;206             }207             count++;208         }209 210         // Calculate sub linear equation211         double[] result2 = calcLinearEquation(aa, bb);212 213         // After sub linear calculation, calculate the current coefficient214         double sum = b[posx];215         count = 0;216         for (int i = 0; i <= len; i++) {217             if (i == posy) {218                 continue;219             }220             sum -= a[posx][i] * result2[count];221             result[i] = result2[count];222             count++;223         }224         result[posy] = sum / a[posx][posy];225         return result;226     }227 228     public static void main(String[] args) {229         JavaLeastSquare eastSquareMethod = new JavaLeastSquare(230                 new double[]{231                         2, 14, 20, 25, 26, 34,232                         47, 87, 165, 265, 365, 465,233                         565, 665234                 },235                 new double[]{236                         0.7 * 2 + 20 + 0.4,237                         0.7 * 14 + 20 + 0.5,238                         0.7 * 20 + 20 + 3.4,239                         0.7 * 25 + 20 + 5.8,240                         0.7 * 26 + 20 + 8.27,241                         0.7 * 34 + 20 + 0.4,242 243                         0.7 * 47 + 20 + 0.1,244                         0.7 * 87 + 20,245                         0.7 * 165 + 20,246                         0.7 * 265 + 20,247                         0.7 * 365 + 20,248                         0.7 * 465 + 20,249 250                         0.7 * 565 + 20,251                         0.7 * 665 + 20252                 },253                 2);254 255         double[] coefficients = eastSquareMethod.getCoefficient();256         for (double c : coefficients) {257             System.out.println(c);258         }259 260         // 测试261         System.out.println(eastSquareMethod.fit(4));262     }263 }

输出结果:

com.datangmobile.biz.leastsquare.JavaLeastSquare

22.27966881467629
0.6952475907448203
25.06065917765557

Process finished with exit code 0

使用开源库

也可使用Apache开源库commons math(http://commons.apache.org/proper/commons-math/userguide/fitting.html),提供的功能更强大:

org.apache.commons
commons-math3
3.5

实现代码:

import org.apache.commons.math3.fitting.PolynomialCurveFitter;import org.apache.commons.math3.fitting.WeightedObservedPoints;public class WeightedObservedPointsTest {    public static void main(String[] args) {        final WeightedObservedPoints obs = new WeightedObservedPoints();        obs.add(2,  0.7 * 2 + 20 + 0.4);        obs.add(12,  0.7 * 12 + 20 + 0.3);        obs.add(32,  0.7 * 32 + 20 + 3.4);        obs.add(34 ,  0.7 * 34 + 20 + 5.8);        obs.add(58 , 0.7 * 58 + 20 + 8.4);        obs.add(43 , 0.7 * 43 + 20 + 0.28);        obs.add(27 , 0.7 * 27 + 20 + 0.4);        // Instantiate a two-degree polynomial fitter.        final PolynomialCurveFitter fitter = PolynomialCurveFitter.create(2);        // Retrieve fitted parameters (coefficients of the polynomial function).        final double[] coeff = fitter.fit(obs.toList());        for (double c : coeff) {            System.out.println(c);        }    }}

测试输出结果:

20.47425047847121

0.6749744063035112
0.002523043547711147

Process finished with exit code 0

使用org.ujmp(矩阵)实现最小二乘法:

pom.xml中需要引入org.ujmp

4.0.0
com.dtgroup
dtgroup
0.0.1-SNAPSHOT
limaven
aliyun maven
http://maven.aliyun.com/nexus/content/groups/public/
default
true
false
org.ujmp
ujmp-core
0.3.0

java代码:

/**     * 采用最小二乘法多项式拟合方式,获取多项式的系数。     * @param sampleCount 采样点个数     * @param fetureCount 多项式的系数     * @param samples 采样点集合     * **/    private static void leastsequare(int sampleCount, int fetureCout, List
samples) { // 构件 2*2矩阵 存储X,元素值都为1.0000的矩阵 Matrix matrixX = DenseMatrix.Factory.ones(sampleCount, fetureCout); for (int i = 0; i < samples.size(); i++) { matrixX.setAsDouble(samples.get(i).getX(), i, 1); } // System.out.println(matrixX); System.out.println("--------------------------------------"); // 构件 2*2矩阵 存储X Matrix matrixY = DenseMatrix.Factory.ones(sampleCount, 1); for (int i = 0; i < samples.size(); i++) { matrixY.setAsDouble(samples.get(i).getY(), i, 0); } // System.out.println(matrixY); // 对X进行转置 Matrix matrixXTrans = matrixX.transpose(); // System.out.println(matrixXTrans); // 乘积运算:x*转转置后x:matrixXTrans*matrixX Matrix matrixMtimes = matrixXTrans.mtimes(matrixX); System.out.println(matrixMtimes); System.out.println("--------------------------------------"); // 求逆 Matrix matrixMtimesInv = matrixMtimes.inv(); System.out.println(matrixMtimesInv); // x转置后结果*求逆结果 System.out.println("--------------------------------------"); Matrix matrixMtimesInvMtimes = matrixMtimesInv.mtimes(matrixXTrans); System.out.println(matrixMtimesInvMtimes); System.out.println("--------------------------------------"); Matrix theta = matrixMtimesInvMtimes.mtimes(matrixY); System.out.println(theta); }

测试代码:

public static void main(String[] args) {        /**         * y=ax+b         *          * a(0,1] b[5,20]         *          * x[0,500] y>=5         */        // y= 0.8d*x+15        // 当x不变动时,y对应有多个值;此时把y求均值。        List
samples = new ArrayList
(); samples.add(new Sample(0.8d * 1 + 15 + 1, 1d)); samples.add(new Sample(0.8d * 4 + 15 + 0.8, 4d)); samples.add(new Sample(0.8d * 3 + 15 + 0.7, 3d)); samples.add(new Sample(0.8d * 24 + 15 + 0.4, 24d)); samples.add(new Sample(0.8d * 5 + 15 + 0.3, 5d)); samples.add(new Sample(0.8d * 10 + 15 + 0.4, 10d)); samples.add(new Sample(0.8d * 14 + 15 + 0.2, 14d)); samples.add(new Sample(0.8d * 7 + 15 + 0.3, 7d)); samples.add(new Sample(0.8d * 1000 + 23 + 0.3, 70d)); int sampleCount = samples.size(); int fetureCout = 2; leastsequare(sampleCount, fetureCout, samples); }

过滤样本中的噪点:

public static void main(String[] args) {        /**         * y=ax+b         *          * a(0,1] b[5,20]         *          * x[0,500] y>=5         */        // y= 0.8d*x+15        // 当x不变动时,y对应有多个值;此时把y求均值。        List
samples = new ArrayList
(); samples.add(new Sample(0.8d * 1 + 15 + 1, 1d)); samples.add(new Sample(0.8d * 4 + 15 + 0.8, 4d)); samples.add(new Sample(0.8d * 3 + 15 + 0.7, 3d)); samples.add(new Sample(0.8d * 24 + 15 + 0.4, 24d)); samples.add(new Sample(0.8d * 5 + 15 + 0.3, 5d)); samples.add(new Sample(0.8d * 10 + 15 + 0.4, 10d)); samples.add(new Sample(0.8d * 14 + 15 + 0.2, 14d)); samples.add(new Sample(0.8d * 7 + 15 + 0.3, 7d)); samples.add(new Sample(0.8d * 1000 + 23 + 0.3, 70d)); // samples = filterSample(samples); sortSample(samples); FilterSampleByGradientResult result = filterSampleByGradient(0, samples); while (result.isComplete() == false) { List
newSamples=result.getSamples(); sortSample(newSamples); result = filterSampleByGradient(result.getIndex(), newSamples); } samples = result.getSamples(); for (Sample sample : samples) { System.out.println(sample); } int sampleCount = samples.size(); int fetureCout = 2; leastsequare(sampleCount, fetureCout, samples); } /** * 对采样点进行排序,按照x排序,升序排列 * @param samples 采样点集合 * **/ private static void sortSample(List
samples) { samples.sort(new Comparator
() { public int compare(Sample o1, Sample o2) { if (o1.getX() > o2.getX()) { return 1; } else if (o1.getX() <= o2.getX()) { return -1; } return 0; } }); } /** * 过滤采样点中的噪点(采样过滤方式:double theta=(y2-y1)/(x2-x1),theta就是一个斜率,根据该值范围来过滤。) * @param index 记录上次过滤索引 * @param samples 采样点集合(将从其中过滤掉噪点) * **/ private static FilterSampleByGradientResult filterSampleByGradient(int index, List
samples) { int sampleSize = samples.size(); for (int i = index; i < sampleSize - 1; i++) { double delta_x = samples.get(i).getX() - samples.get(i + 1).getX(); double delta_y = samples.get(i).getY() - samples.get(i + 1).getY(); // 距离小于2米 if (Math.abs(delta_x) < 1) { double newY = (samples.get(i).getY() + samples.get(i + 1).getY()) / 2; double newX = samples.get(i).getX(); samples.remove(i); samples.remove(i + 1); samples.add(new Sample(newY, newX)); return new FilterSampleByGradientResult(false, i, samples); } else { double gradient = delta_y / delta_x; if (gradient > 1.5) { if (i == 0) { // double newY = (samples.get(i).getY() + samples.get(i // + 1).getY()) / 2; // double newX = (samples.get(i).getX() + samples.get(i // + 1).getX()) / 2; // samples.remove(i); // samples.add(new Sample(newY, newX)); } else { samples.remove(i + 1); } return new FilterSampleByGradientResult(false, i, samples); } } } return new FilterSampleByGradientResult(true, 0, samples); }

 使用距离来处理过滤:

private static List
filterSample(List
samples) { // x={x1,x2,x3...xn} // u=E(x) ---x的期望(均值)为 u // 6=sqrt(pow((x1-u),2)+pow((x2-u),2)+pow((x3-u),2)+...+pow((xn-u),2)) // 6为x的标准差,标准差=sqrt(方差) // 剔除噪点可以采用: // 若xi不属于(u-3*6,u+3*6),则认为它是噪点。 // 另外一种方案,对x/y都做上边的处理,之后如果两个结果为and 或者 or操作来选取是否剔除。 // 用点的方式来过滤数据,求出一个中值点,求其他点到该点的距离。 int sampleCount = samples.size(); double sumX = 0d; double sumY = 0d; for (Sample sample : samples) { sumX += sample.getX(); sumY += sample.getY(); } // 求中心点 double centerX = (sumX / sampleCount); double centerY = (sumY / sampleCount); List
distanItems = new ArrayList
(); // 计算出所有点距离该中心点的距离 for (int i = 0; i < samples.size(); i++) { Sample sample = samples.get(i); Double xyPow2 = Math.pow(sample.getX() - centerX, 2) + Math.pow(sample.getY() - centerY, 2); distanItems.add(Math.sqrt(xyPow2)); } // 以下对根据距离(所有点距离中心点的距离)进行筛选 double sumDistan = 0d; double distanceU = 0d; for (Double distance : distanItems) { sumDistan += distance; } distanceU = sumDistan / sampleCount; double deltaPowSum = 0d; double distanceTheta = 0d; // sqrt(pow((x1-u),2)+pow((x2-u),2)+pow((x3-u),2)+...+pow((xn-u),2)) for (Double distance : distanItems) { deltaPowSum += Math.pow((distance - distanceU), 2); } distanceTheta = Math.sqrt(deltaPowSum); // 剔除噪点可以采用: // 若xi不属于(u-3*6,u+3*6),则认为它是噪点。 double minDistance = distanceU - 0.5 * distanceTheta; double maxDistance = distanceU + 0.5 * distanceTheta; List
willbeRemoveIdxs = new ArrayList
(); for (int i = distanItems.size() - 1; i >= 0; i--) { Double distance = distanItems.get(i); if (distance <= minDistance || distance >= maxDistance) { willbeRemoveIdxs.add(i); System.out.println("will be remove " + i); } } for (int willbeRemoveIdx : willbeRemoveIdxs) { samples.remove(willbeRemoveIdx); } return samples; }

实际业务测试:

package com.zjanalyse.spark.maths;import java.util.ArrayList;import java.util.Comparator;import java.util.List;import org.ujmp.core.DenseMatrix;import org.ujmp.core.Matrix;public class LastSquare {    /**     * y=ax+b a(0,1] b[5,20] x[0,500] y>=5     */    public static void main(String[] args) {        // y= 0.8d*x+15        // 当x不变动时,y对应有多个值;此时把y求均值。        List
samples = new ArrayList
(); samples.add(new Sample(0.8d * 11 + 15 + 1, 11d)); samples.add(new Sample(0.8d * 24 + 15 + 0.8, 24d)); samples.add(new Sample(0.8d * 33 + 15 + 0.7, 33d)); samples.add(new Sample(0.8d * 24 + 15 + 0.4, 24d)); samples.add(new Sample(0.8d * 47 + 15 + 0.3, 47d)); samples.add(new Sample(0.8d * 60 + 15 + 0.4, 60d)); samples.add(new Sample(0.8d * 14 + 15 + 0.2, 14d)); samples.add(new Sample(0.8d * 57 + 15 + 0.3, 57d)); samples.add(new Sample(0.8d * 70 + 60 + 0.3, 70d)); samples.add(new Sample(0.8d * 80 + 60 + 0.3, 80d)); samples.add(new Sample(0.8d * 40 + 30 + 0.3, 40d)); sortSample(samples); System.out.println("原始样本数据"); for (Sample sample : samples) { System.out.println(sample); } System.out.println("开始“所有点”通过“业务数据取值范围”剔除:"); // 按照业务过滤。。。 filterByBusiness(samples); System.out.println("结束“所有点”通过“业务数据取值范围”剔除:"); for (Sample sample : samples) { System.out.println(sample); } int sampleCount = samples.size(); int fetureCout = 2; System.out.println("第一次拟合。。。"); Matrix theta = leastsequare(sampleCount, fetureCout, samples); double wear_loss = theta.getAsDouble(0, 0); double path_loss = theta.getAsDouble(1, 0); System.out.println("wear loss " + wear_loss); System.out.println("path loss " + path_loss); System.out.println("开始“所有点”与“第一多项式拟合结果直线方式距离方差”剔除:"); samples = filterSample(wear_loss, path_loss, samples); System.out.println("结束“所有点”与“第一多项式拟合结果直线方式距离方差”剔除:"); for (Sample sample : samples) { System.out.println(sample); } System.out.println("第二次拟合。。。"); sampleCount = samples.size(); fetureCout = 2; if (sampleCount >= 2) { theta = leastsequare(sampleCount, fetureCout, samples); wear_loss = theta.getAsDouble(0, 0); path_loss = theta.getAsDouble(1, 0); System.out.println("wear loss " + wear_loss); System.out.println("path loss " + path_loss); } System.out.println("complete..."); } /** * 按照业务过滤有效值范围 */ private static void filterByBusiness(List
samples) { for (int i = 0; i < samples.size(); i++) { double x = samples.get(i).getX(); double y = samples.get(i).getY(); if (x >= 500) { System.out.println(x + " x值超出有效值范围[0,500)"); samples.remove(i); i--; } // y= 0.8d*x+15 else if (y < 0 * x + 5 || y > 1 * x + 30) { System.out.println( y + " y值超出有效值范围[(0*x+5),(1*x+30)]其中x=" + x + ",也就是:[" + (0 * x + 5) + "," + (1 * x + 30) + ")"); samples.remove(i); i--; } } } /** * Description 点到直线的距离 * * @param x1 * 点横坐标 * @param y1 * 点纵坐标 * @param A * 直线方程一般式系数A * @param B * 直线方程一般式系数B * @param C * 直线方程一般式系数C * @return 点到之间的距离 * @see 点0,1到之前y=x+0的距离
* double distance = getDistanceOfPerpendicular(0,0, -1, 1, 0);
* System.out.println(distance);
*/ private static double getDistanceOfPerpendicular(double x1, double y1, double A, double B, double C) { double distance = Math.abs((A * x1 + B * y1 + C) / Math.sqrt(A * A + B * B)); return distance; } private static List
filterSample(double wear_loss, double path_loss, List
samples) { // x={x1,x2,x3...xn} // u=E(x) ---x的期望(均值)为 u // 6=sqrt(pow((x1-u),2)+pow((x2-u),2)+pow((x3-u),2)+...+pow((xn-u),2)) // 6为x的标准差,标准差=sqrt(方差) // 剔除噪点可以采用: // 若xi不属于(u-3*6,u+3*6),则认为它是噪点。 // 求出所有点距离第一次拟合结果的直线方程的距离 int sampleCount = samples.size(); List
distanItems = new ArrayList
(); // 计算出所有点距离该中心点的距离 for (int i = 0; i < samples.size(); i++) { Sample sample = samples.get(i); double distance = getDistanceOfPerpendicular(sample.getX(), sample.getY(), path_loss, -1, wear_loss); distanItems.add(Math.sqrt(distance)); } // 以下对根据距离(所有点距离中心点的距离)进行筛选 double sumDistan = 0d; double distanceU = 0d; for (Double distance : distanItems) { sumDistan += distance; } distanceU = sumDistan / sampleCount; double deltaPowSum = 0d; double distanceTheta = 0d; // sqrt(pow((x1-u),2)+pow((x2-u),2)+pow((x3-u),2)+...+pow((xn-u),2)) for (Double distance : distanItems) { deltaPowSum += Math.pow((distance - distanceU), 2); } distanceTheta = Math.sqrt(deltaPowSum); // 剔除噪点可以采用: // 若xi不属于(u-3*6,u+3*6),则认为它是噪点。 double minDistance = distanceU - 0.25 * distanceTheta; double maxDistance = distanceU + 0.25 * distanceTheta; List
willbeRemoveIdxs = new ArrayList
(); for (int i = distanItems.size() - 1; i >= 0; i--) { Double distance = distanItems.get(i); if (distance <= minDistance || distance >= maxDistance) { System.out.println(distance + " out of range [" + minDistance + "," + maxDistance + "]"); willbeRemoveIdxs.add(i); } else { System.out.println(distance); } } for (int willbeRemoveIdx : willbeRemoveIdxs) { Sample sample = samples.get(willbeRemoveIdx); System.out.println("remove " + sample); samples.remove(willbeRemoveIdx); } return samples; } /** * 对采样点进行排序,按照x排序,升序排列 * * @param samples * 采样点集合 **/ private static void sortSample(List
samples) { samples.sort(new Comparator
() { public int compare(Sample o1, Sample o2) { if (o1.getX() > o2.getX()) { return 1; } else if (o1.getX() <= o2.getX()) { return -1; } return 0; } }); } /** * Description 采用最小二乘法多项式拟合方式,获取多项式的系数。 * * @param sampleCount * 采样点个数 * @param fetureCount * 多项式的系数 * @param samples * 采样点集合 **/ private static Matrix leastsequare(int sampleCount, int fetureCout, List
samples) { // 构件 2*2矩阵 存储X,元素值都为1.0000的矩阵 Matrix matrixX = DenseMatrix.Factory.ones(sampleCount, fetureCout); for (int i = 0; i < samples.size(); i++) { matrixX.setAsDouble(samples.get(i).getX(), i, 1); } // System.out.println(matrixX); // System.out.println("--------------------------------------"); // 构件 2*2矩阵 存储X Matrix matrixY = DenseMatrix.Factory.ones(sampleCount, 1); for (int i = 0; i < samples.size(); i++) { matrixY.setAsDouble(samples.get(i).getY(), i, 0); } // System.out.println(matrixY); // 对X进行转置 Matrix matrixXTrans = matrixX.transpose(); // System.out.println(matrixXTrans); // 乘积运算:x*转转置后x:matrixXTrans*matrixX Matrix matrixMtimes = matrixXTrans.mtimes(matrixX); // System.out.println(matrixMtimes); // System.out.println("--------------------------------------"); // 求逆 Matrix matrixMtimesInv = matrixMtimes.inv(); // System.out.println(matrixMtimesInv); // x转置后结果*求逆结果 // System.out.println("--------------------------------------"); Matrix matrixMtimesInvMtimes = matrixMtimesInv.mtimes(matrixXTrans); // System.out.println(matrixMtimesInvMtimes); // System.out.println("--------------------------------------"); Matrix theta = matrixMtimesInvMtimes.mtimes(matrixY); // System.out.println(theta); return theta; }}
View Code

 

你可能感兴趣的文章
Opencv探索之路(十九):读写xml和yml文件
查看>>
Eclipse插件开发中的选择监听机制(Selection Provider-Listener)
查看>>
14.并发与异步 - 2.任务Task -《果壳中的c#》
查看>>
Linux时间子系统之三:jiffies
查看>>
使用 VisualVM 进行性能分析及调优
查看>>
linux升级OpenSSL
查看>>
《QQ欢乐斗地主》山寨版
查看>>
病毒木马查杀实战第015篇:U盘病毒之脱壳研究
查看>>
SDK是什么?什么是SDK
查看>>
centos/linux下的使得maven/tomcat能在普通用户是使用
查看>>
Web学习篇之---html基础知识(一)
查看>>
java多线程入门学习(一)
查看>>
canvas图形处理和进阶用法
查看>>
1. 请问PHP里的ECHO是什么意思 ?请问PHP里的ECHO是什么意思???有什么作用???又应该怎么使用???...
查看>>
ES6,数组遍历
查看>>
如何把浏览器不信任的网址设置为可信任的网点
查看>>
脚本加密http://www.datsi.fi.upm.es/~frosal/sources/
查看>>
Cocos Studio is EOL'd
查看>>
linux shell下16进制 “\uxxxx” unicode to UTF-8中文
查看>>
【WPF】树形结构TreeView的用法(MVVM)
查看>>