改变kNN算法中k的值-Java

Altering the value of k in kNN algorithm - Java

我已应用 KNN 算法对手写数字进行分类。数字最初是 8*8 的矢量格式,然后拉伸形成一个 1*64 的矢量。

就目前而言,我的代码应用了 kNN 算法,但只使用了 k = 1。在尝试了几件事后,我不完全确定如何更改 k 值,但我一直在抛出错误。如果有人能帮助我朝着正确的方向前进,我将不胜感激。训练数据集可以在这里找到,验证集在这里。

ImageMatrix.java

import java.util.*;



public class ImageMatrix {

  private int[] data;

  private int classCode;

  private int curData;

public ImageMatrix(int[] data, int classCode) {

  assert data.length == 64; //maximum array length of 64

  this.data = data;

  this.classCode = classCode;

}



  public String toString() {

    return"Class Code:" + classCode +" Data :" + Arrays.toString(data) +"\

"; //outputs readable

  }



  public int[] getData() {

    return data;

  }



  public int getClassCode() {

    return classCode;

  }

  public int getCurData() {

    return curData;

  }







}

import java.util.*;

import java.io.*;

import java.util.ArrayList;

public class ImageMatrixDB implements Iterable<ImageMatrix> {

  private List<ImageMatrix> list = new ArrayList<ImageMatrix>();



  public ImageMatrixDB load(String f) throws IOException {

    try (

      FileReader fr = new FileReader(f);

      BufferedReader br = new BufferedReader(fr)) {

      String line = null;



      while((line = br.readLine()) != null) {

        int lastComma = line.lastIndexOf(',');

        int classCode = Integer.parseInt(line.substring(1 + lastComma));

        int[] data = Arrays.stream(line.substring(0, lastComma).split(","))

                 .mapToInt(Integer::parseInt)

                 .toArray();

        ImageMatrix matrix = new ImageMatrix(data, classCode); // Classcode->100% when 0 -> 0% when 1 - 9..

        list.add(matrix);

      }

    }

    return this;

  }



  public void printResults(){ //output results 

    for(ImageMatrix matrix: list){

      System.out.println(matrix);

    }

  }





  public Iterator<ImageMatrix> iterator() {

    return this.list.iterator();

  }



  /// kNN implementation ///

  public static int distance(int[] a, int[] b) {

    int sum = 0;

    for(int i = 0; i < a.length; i++) {

      sum += (a[i] - b[i]) * (a[i] - b[i]);

    }

    return (int)Math.sqrt(sum);

  }





  public static int classify(ImageMatrixDB trainingSet, int[] curData) {

    int label = 0, bestDistance = Integer.MAX_VALUE;

    for(ImageMatrix matrix: trainingSet) {

      int dist = distance(matrix.getData(), curData);

      if(dist < bestDistance) {

        bestDistance = dist;

        label = matrix.getClassCode();

      }

    }

    return label;

  }





  public int size() {



    return list.size(); //returns size of the list



    }





  public static void main(String[] argv) throws IOException {

    ImageMatrixDB trainingSet = new ImageMatrixDB();

    ImageMatrixDB validationSet = new ImageMatrixDB();

    trainingSet.load("cw2DataSet1.csv");

    validationSet.load("cw2DataSet2.csv"); 

    int numCorrect = 0;

    for(ImageMatrix matrix:validationSet) {

      if(classify(trainingSet, matrix.getData()) == matrix.getClassCode()) numCorrect++;

    } //285 correct

    System.out.println("Accuracy:" + (double)numCorrect / validationSet.size() * 100 +"%"); 

    System.out.println();

  }

public static int classify(ImageMatrixDB trainingSet, int[] curData, int k) {

  int label = 0, bestDistance = Integer.MAX_VALUE;

  int[][] distances = new int[trainingSet.size()][2];

  int i=0;



  // Place distances in an array to be sorted

  for(ImageMatrix matrix: trainingSet) {

    distances[i][0] = distance(matrix.getData(), curData);

    distances[i][1] = matrix.getClassCode();

    i++;

  }



  Arrays.sort(distances, (int[] lhs, int[] rhs) -> lhs[0]-rhs[0]);



  // Find frequencies of each class code

  i = 0;

  Map<Integer,Integer> majorityMap;

  majorityMap = new HashMap<Integer,Integer>();

  while(i < k) {

    if( majorityMap.containsKey( distances[i][1] ) ) {

      int currentValue = majorityMap.get(distances[i][1]);

      majorityMap.put(distances[i][1], currentValue + 1);

    }

    else {

      majorityMap.put(distances[i][1], 1);

    }

    ++i;

  }



  // Find the class code with the highest frequency

  int maxVal = -1;

  for (Entry<Integer, Integer> entry: majorityMap.entrySet()) {

    int entryVal = entry.getValue();

    if(entryVal > maxVal) {

      maxVal = entryVal;

      label = entry.getKey();

    }

  }



  return label;

}

ImageMatrixDB.java

import java.util.*;



public class ImageMatrix {

  private int[] data;

  private int classCode;

  private int curData;

public ImageMatrix(int[] data, int classCode) {

  assert data.length == 64; //maximum array length of 64

  this.data = data;

  this.classCode = classCode;

}



  public String toString() {

    return"Class Code:" + classCode +" Data :" + Arrays.toString(data) +"\

"; //outputs readable

  }



  public int[] getData() {

    return data;

  }



  public int getClassCode() {

    return classCode;

  }

  public int getCurData() {

    return curData;

  }







}

import java.util.*;

import java.io.*;

import java.util.ArrayList;

public class ImageMatrixDB implements Iterable<ImageMatrix> {

  private List<ImageMatrix> list = new ArrayList<ImageMatrix>();



  public ImageMatrixDB load(String f) throws IOException {

    try (

      FileReader fr = new FileReader(f);

      BufferedReader br = new BufferedReader(fr)) {

      String line = null;



      while((line = br.readLine()) != null) {

        int lastComma = line.lastIndexOf(',');

        int classCode = Integer.parseInt(line.substring(1 + lastComma));

        int[] data = Arrays.stream(line.substring(0, lastComma).split(","))

                 .mapToInt(Integer::parseInt)

                 .toArray();

        ImageMatrix matrix = new ImageMatrix(data, classCode); // Classcode->100% when 0 -> 0% when 1 - 9..

        list.add(matrix);

      }

    }

    return this;

  }



  public void printResults(){ //output results 

    for(ImageMatrix matrix: list){

      System.out.println(matrix);

    }

  }





  public Iterator<ImageMatrix> iterator() {

    return this.list.iterator();

  }



  /// kNN implementation ///

  public static int distance(int[] a, int[] b) {

    int sum = 0;

    for(int i = 0; i < a.length; i++) {

      sum += (a[i] - b[i]) * (a[i] - b[i]);

    }

    return (int)Math.sqrt(sum);

  }





  public static int classify(ImageMatrixDB trainingSet, int[] curData) {

    int label = 0, bestDistance = Integer.MAX_VALUE;

    for(ImageMatrix matrix: trainingSet) {

      int dist = distance(matrix.getData(), curData);

      if(dist < bestDistance) {

        bestDistance = dist;

        label = matrix.getClassCode();

      }

    }

    return label;

  }





  public int size() {



    return list.size(); //returns size of the list



    }





  public static void main(String[] argv) throws IOException {

    ImageMatrixDB trainingSet = new ImageMatrixDB();

    ImageMatrixDB validationSet = new ImageMatrixDB();

    trainingSet.load("cw2DataSet1.csv");

    validationSet.load("cw2DataSet2.csv"); 

    int numCorrect = 0;

    for(ImageMatrix matrix:validationSet) {

      if(classify(trainingSet, matrix.getData()) == matrix.getClassCode()) numCorrect++;

    } //285 correct

    System.out.println("Accuracy:" + (double)numCorrect / validationSet.size() * 100 +"%"); 

    System.out.println();

  }

public static int classify(ImageMatrixDB trainingSet, int[] curData, int k) {

  int label = 0, bestDistance = Integer.MAX_VALUE;

  int[][] distances = new int[trainingSet.size()][2];

  int i=0;



  // Place distances in an array to be sorted

  for(ImageMatrix matrix: trainingSet) {

    distances[i][0] = distance(matrix.getData(), curData);

    distances[i][1] = matrix.getClassCode();

    i++;

  }



  Arrays.sort(distances, (int[] lhs, int[] rhs) -> lhs[0]-rhs[0]);



  // Find frequencies of each class code

  i = 0;

  Map<Integer,Integer> majorityMap;

  majorityMap = new HashMap<Integer,Integer>();

  while(i < k) {

    if( majorityMap.containsKey( distances[i][1] ) ) {

      int currentValue = majorityMap.get(distances[i][1]);

      majorityMap.put(distances[i][1], currentValue + 1);

    }

    else {

      majorityMap.put(distances[i][1], 1);

    }

    ++i;

  }



  // Find the class code with the highest frequency

  int maxVal = -1;

  for (Entry<Integer, Integer> entry: majorityMap.entrySet()) {

    int entryVal = entry.getValue();

    if(entryVal > maxVal) {

      maxVal = entryVal;

      label = entry.getKey();

    }

  }



  return label;

}

在分类的 for 循环中,您试图找到最接近测试点的训练示例。您需要使用找到最接近测试数据的 K 个训练点的代码来切换它。然后你应该为这些 K 点中的每一个调用 getClassCode 并找到其中大多数(即最频繁)的类代码。然后,分类将返回您找到的主要类代码。

您可以以任何适合您需要的方式打破联系(即,将 2 个最常见的类代码分配给相同数量的训练数据)。

我在Java方面真的很缺乏经验,但是只是通过查看语言参考,我想出了下面的实现。

import java.util.*;



public class ImageMatrix {

  private int[] data;

  private int classCode;

  private int curData;

public ImageMatrix(int[] data, int classCode) {

  assert data.length == 64; //maximum array length of 64

  this.data = data;

  this.classCode = classCode;

}



  public String toString() {

    return"Class Code:" + classCode +" Data :" + Arrays.toString(data) +"\

"; //outputs readable

  }



  public int[] getData() {

    return data;

  }



  public int getClassCode() {

    return classCode;

  }

  public int getCurData() {

    return curData;

  }







}

import java.util.*;

import java.io.*;

import java.util.ArrayList;

public class ImageMatrixDB implements Iterable<ImageMatrix> {

  private List<ImageMatrix> list = new ArrayList<ImageMatrix>();



  public ImageMatrixDB load(String f) throws IOException {

    try (

      FileReader fr = new FileReader(f);

      BufferedReader br = new BufferedReader(fr)) {

      String line = null;



      while((line = br.readLine()) != null) {

        int lastComma = line.lastIndexOf(',');

        int classCode = Integer.parseInt(line.substring(1 + lastComma));

        int[] data = Arrays.stream(line.substring(0, lastComma).split(","))

                 .mapToInt(Integer::parseInt)

                 .toArray();

        ImageMatrix matrix = new ImageMatrix(data, classCode); // Classcode->100% when 0 -> 0% when 1 - 9..

        list.add(matrix);

      }

    }

    return this;

  }



  public void printResults(){ //output results 

    for(ImageMatrix matrix: list){

      System.out.println(matrix);

    }

  }





  public Iterator<ImageMatrix> iterator() {

    return this.list.iterator();

  }



  /// kNN implementation ///

  public static int distance(int[] a, int[] b) {

    int sum = 0;

    for(int i = 0; i < a.length; i++) {

      sum += (a[i] - b[i]) * (a[i] - b[i]);

    }

    return (int)Math.sqrt(sum);

  }





  public static int classify(ImageMatrixDB trainingSet, int[] curData) {

    int label = 0, bestDistance = Integer.MAX_VALUE;

    for(ImageMatrix matrix: trainingSet) {

      int dist = distance(matrix.getData(), curData);

      if(dist < bestDistance) {

        bestDistance = dist;

        label = matrix.getClassCode();

      }

    }

    return label;

  }





  public int size() {



    return list.size(); //returns size of the list



    }





  public static void main(String[] argv) throws IOException {

    ImageMatrixDB trainingSet = new ImageMatrixDB();

    ImageMatrixDB validationSet = new ImageMatrixDB();

    trainingSet.load("cw2DataSet1.csv");

    validationSet.load("cw2DataSet2.csv"); 

    int numCorrect = 0;

    for(ImageMatrix matrix:validationSet) {

      if(classify(trainingSet, matrix.getData()) == matrix.getClassCode()) numCorrect++;

    } //285 correct

    System.out.println("Accuracy:" + (double)numCorrect / validationSet.size() * 100 +"%"); 

    System.out.println();

  }

public static int classify(ImageMatrixDB trainingSet, int[] curData, int k) {

  int label = 0, bestDistance = Integer.MAX_VALUE;

  int[][] distances = new int[trainingSet.size()][2];

  int i=0;



  // Place distances in an array to be sorted

  for(ImageMatrix matrix: trainingSet) {

    distances[i][0] = distance(matrix.getData(), curData);

    distances[i][1] = matrix.getClassCode();

    i++;

  }



  Arrays.sort(distances, (int[] lhs, int[] rhs) -> lhs[0]-rhs[0]);



  // Find frequencies of each class code

  i = 0;

  Map<Integer,Integer> majorityMap;

  majorityMap = new HashMap<Integer,Integer>();

  while(i < k) {

    if( majorityMap.containsKey( distances[i][1] ) ) {

      int currentValue = majorityMap.get(distances[i][1]);

      majorityMap.put(distances[i][1], currentValue + 1);

    }

    else {

      majorityMap.put(distances[i][1], 1);

    }

    ++i;

  }



  // Find the class code with the highest frequency

  int maxVal = -1;

  for (Entry<Integer, Integer> entry: majorityMap.entrySet()) {

    int entryVal = entry.getValue();

    if(entryVal > maxVal) {

      maxVal = entryVal;

      label = entry.getKey();

    }

  }



  return label;

}

您需要做的就是添加 K 作为参数。但是请记住,上面的代码并没有以特定方式处理关系。


相关推荐

  • Spring部署设置openshift

    Springdeploymentsettingsopenshift我有一个问题让我抓狂了三天。我根据OpenShift帐户上的教程部署了spring-eap6-quickstart代码。我已配置调试选项,并且已将Eclipse工作区与OpehShift服务器同步-服务器上的一切工作正常,但在Eclipse中出现无法消除的错误。我有这个错误:cvc-complex-type.2.4.a:Invali…
    2025-04-161
  • 检查Java中正则表达式中模式的第n次出现

    CheckfornthoccurrenceofpatterninregularexpressioninJava本问题已经有最佳答案,请猛点这里访问。我想使用Java正则表达式检查输入字符串中特定模式的第n次出现。你能建议怎么做吗?这应该可以工作:MatchResultfindNthOccurance(intn,Patternp,CharSequencesrc){Matcherm=p.matcher…
    2025-04-161
  • 如何让 JTable 停留在已编辑的单元格上

    HowtohaveJTablestayingontheeditedcell如果有人编辑JTable的单元格内容并按Enter,则内容会被修改并且表格选择会移动到下一行。是否可以禁止JTable在单元格编辑后转到下一行?原因是我的程序使用ListSelectionListener在单元格选择上同步了其他一些小部件,并且我不想在编辑当前单元格后选择下一行。Enter的默认绑定是名为selectNext…
    2025-04-161
  • Weblogic 12c 部署

    Weblogic12cdeploy我正在尝试将我的应用程序从Tomcat迁移到Weblogic12.2.1.3.0。我能够毫无错误地部署应用程序,但我遇到了与持久性提供程序相关的运行时错误。这是堆栈跟踪:javax.validation.ValidationException:CalltoTraversableResolver.isReachable()threwanexceptionatorg.…
    2025-04-161
  • Resteasy Content-Type 默认值

    ResteasyContent-Typedefaults我正在使用Resteasy编写一个可以返回JSON和XML的应用程序,但可以选择默认为XML。这是我的方法:@GET@Path("/content")@Produces({MediaType.APPLICATION_XML,MediaType.APPLICATION_JSON})publicStringcontentListRequestXm…
    2025-04-161
  • 代码不会停止运行,在 Java 中

    thecodedoesn'tstoprunning,inJava我正在用Java解决项目Euler中的问题10,即"Thesumoftheprimesbelow10is2+3+5+7=17.Findthesumofalltheprimesbelowtwomillion."我的代码是packageprojecteuler_1;importjava.math.BigInteger;importjava…
    2025-04-161
  • Out of memory java heap space

    Outofmemoryjavaheapspace我正在尝试将大量文件从服务器发送到多个客户端。当我尝试发送大小为700mb的文件时,它显示了"OutOfMemoryjavaheapspace"错误。我正在使用Netbeans7.1.2版本。我还在属性中尝试了VMoption。但仍然发生同样的错误。我认为阅读整个文件存在一些问题。下面的代码最多可用于300mb。请给我一些建议。提前致谢publicc…
    2025-04-161
  • Log4j 记录到共享日志文件

    Log4jLoggingtoaSharedLogFile有没有办法将log4j日志记录事件写入也被其他应用程序写入的日志文件。其他应用程序可以是非Java应用程序。有什么缺点?锁定问题?格式化?Log4j有一个SocketAppender,它将向服务发送事件,您可以自己实现或使用与Log4j捆绑的简单实现。它还支持syslogd和Windows事件日志,这对于尝试将日志输出与来自非Java应用程序…
    2025-04-161