12.12
Frequent Itemsets mining 最常见的形式就是给定market-basket形式的数据(每一行相当于一个购物篮,包含多个商品item),然后我们找出支持度(即出现频率)不小于某个阈值的所有item集合。A-Priori算法是Frequent Itemsets mining里最基本的算法。
一、A-Priori的基本思路
第一遍(pass 1),扫描文件,统计单项(single item)的出现次数(使用1个map进行统计,key就是item,value就是出现次数)。最后,过滤掉小于最小支持度的,得到频繁单项集。
第二遍(pass 2),扫描文件,对于每一行,对任意两个item组合得到pair item。如果pair item的2个单项都在频繁单项集里,则统计这个pair item的出现次数;否则略过。最后得到所有pair item的出现频率,过滤掉小于最小支持度的,得到频繁2项集。
第N遍,如果发现频繁N-1项集不为空,则说明mining还没有完成,需要进行第N次扫描。从每行中取得任意N个item的组合,如果这N项的所有N-1项子集都在频繁N-1项集里,则统计其出现次数,否则略过。最后,过滤掉小于最小支持度的,得到频繁N项集。
直到频繁项集为空才停止循环。
二、java实现
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

/**
 * A-Priori frequent-itemset miner over a market-basket text file.
 *
 * <p>Each line of the input file is one basket: integer item ids separated by
 * whitespace. Blank lines are ignored, and an item repeated inside one basket
 * is counted once for that basket.
 *
 * <p>Usage: call {@link #findF1Item()} first (pass 1), then
 * {@link #generateNextPass(int)} with k = 2, 3, ... until it returns an empty
 * map. Every pass re-reads the input file and relies on the frequent itemsets
 * recorded by the previous pass, so the passes must be run in order.
 */
public class Apriori {

    private String inputFile;
    private double minSupport;
    private BufferedWriter bw;
    /** Frequent itemsets found by the most recent pass. */
    private HashSet<Set<Integer>> frequentItems = new HashSet<Set<Integer>>();
    /** Every single item that occurs in some itemset of {@link #frequentItems}. */
    private Set<Integer> frequentSingleItems = new HashSet<Integer>();
    /** Number of (non-blank) baskets counted by pass 1. */
    private int totalCount;

    /**
     * @param inputFile  path of the basket file to mine
     * @param minSupport minimum support as a fraction of the basket count;
     *                   an itemset is frequent when
     *                   {@code count >= minSupport * totalCount}
     */
    public Apriori(String inputFile, double minSupport) {
        this.inputFile = inputFile;
        this.minSupport = minSupport;
    }

    /**
     * Pass 1: counts every single item and keeps the frequent ones.
     *
     * @return frequent 1-itemsets mapped to their basket counts, ordered by
     *         ascending item id
     * @throws IOException if the input file cannot be read
     * @throws NumberFormatException if a token is not an integer
     */
    public Map<Set<Integer>, Integer> findF1Item() throws IOException {
        Map<Integer, Integer> counts = new HashMap<Integer, Integer>();
        int baskets = 0;

        BufferedReader reader = new BufferedReader(new FileReader(inputFile));
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                int[] items = parseTransaction(line);
                if (items.length == 0) {
                    continue; // blank line: not a basket
                }
                baskets++;
                for (int item : items) {
                    Integer seen = counts.get(item);
                    counts.put(item, seen == null ? 1 : seen + 1);
                }
            }
        } finally {
            // Release the file handle even when parsing fails mid-scan.
            reader.close();
        }
        totalCount = baskets;

        // Start from a clean slate so a repeated pass 1 cannot mix in old state.
        frequentItems = new HashSet<Set<Integer>>();
        frequentSingleItems = new HashSet<Integer>();

        Map<Set<Integer>, Integer> result = new LinkedHashMap<Set<Integer>, Integer>();
        // A TreeSet orders the items by ascending id for stable output.
        for (Integer item : new TreeSet<Integer>(counts.keySet())) {
            int count = counts.get(item);
            if (isFrequent(count)) {
                Set<Integer> f1Set = new TreeSet<Integer>();
                f1Set.add(item);
                result.put(f1Set, count);
                frequentItems.add(f1Set);
                frequentSingleItems.add(item);
            }
        }
        return result;
    }

    /**
     * Pass k: counts the candidate k-itemsets whose (k-1)-subsets were all
     * frequent in the previous pass, and keeps the frequent ones. The recorded
     * frequent itemsets are replaced by the result of this pass.
     *
     * @param k size of the itemsets to mine in this pass
     * @return frequent k-itemsets mapped to their basket counts; empty when
     *         mining is finished
     * @throws Exception if the input file cannot be read or contains a
     *                   non-integer token
     */
    public Map<Set<Integer>, Integer> generateNextPass(int k) throws Exception {
        Map<Set<Integer>, Integer> counts = new HashMap<Set<Integer>, Integer>();

        BufferedReader reader = new BufferedReader(new FileReader(inputFile));
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                for (Set<Integer> candidate : generateCandidates(parseTransaction(line), k)) {
                    Integer seen = counts.get(candidate);
                    counts.put(candidate, seen == null ? 1 : seen + 1);
                }
            }
        } finally {
            reader.close();
        }

        // The previous pass's sets were needed while counting; only now swap them.
        HashSet<Set<Integer>> nextFrequent = new HashSet<Set<Integer>>();
        Set<Integer> nextSingles = new HashSet<Integer>();
        Iterator<Map.Entry<Set<Integer>, Integer>> it = counts.entrySet().iterator();
        while (it.hasNext()) {
            Map.Entry<Set<Integer>, Integer> entry = it.next();
            if (isFrequent(entry.getValue())) {
                nextFrequent.add(entry.getKey());
                nextSingles.addAll(entry.getKey());
            } else {
                it.remove();
            }
        }
        frequentItems = nextFrequent;
        frequentSingleItems = nextSingles;
        return counts;
    }

    /**
     * Appends the given itemsets to {@code output.txt}, one per line in the
     * form {@code [a b c ], count:n}. The file is opened (and truncated) on the
     * first call and kept open until {@link #closeOutputWriter()}.
     *
     * @param itemSets itemsets and their counts, written in the map's iteration order
     * @param i        pass number; currently unused, kept for caller compatibility
     * @throws FileNotFoundException if {@code output.txt} cannot be created
     * @throws IOException if writing fails
     */
    public void printFrequentItems(Map<Set<Integer>, Integer> itemSets, int i)
            throws FileNotFoundException, IOException {
        if (bw == null) {
            bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("output.txt")));
        }
        StringBuilder sb = new StringBuilder();
        for (Map.Entry<Set<Integer>, Integer> entry : itemSets.entrySet()) {
            sb.append("[");
            for (Integer item : entry.getKey()) {
                sb.append(item).append(" ");
            }
            sb.append("], count:");
            sb.append(entry.getValue());
            sb.append("\n");
        }
        bw.write(sb.toString());
        bw.flush();
    }

    /**
     * Closes the output writer if one was opened; does nothing otherwise.
     *
     * @throws IOException if closing the writer fails
     */
    public void closeOutputWriter() throws IOException {
        // The original condition was inverted (bw == null), which threw a
        // NullPointerException when unopened and never closed an open writer.
        if (bw != null) {
            bw.close();
        }
    }

    /** Returns whether a basket count meets the minimum support threshold. */
    private boolean isFrequent(int count) {
        return count >= minSupport * totalCount;
    }

    /**
     * Parses one input line into its distinct item ids, in first-seen order.
     * Returns an empty array for a blank line.
     */
    private int[] parseTransaction(String line) {
        String trimmed = line.trim();
        if (trimmed.length() == 0) {
            return new int[0];
        }
        String[] tokens = trimmed.split("\\s+");
        Set<Integer> seen = new HashSet<Integer>();
        List<Integer> distinct = new ArrayList<Integer>(tokens.length);
        for (String token : tokens) {
            Integer value = Integer.valueOf(token);
            // Support is counted per basket, so repeats within a basket are dropped;
            // they would also make the combination generator emit undersized sets.
            if (seen.add(value)) {
                distinct.add(value);
            }
        }
        return toIntArray(distinct);
    }

    /**
     * Builds the candidate k-itemsets of one basket: every k-combination of the
     * basket's still-frequent items whose (k-1)-subsets are all in
     * {@link #frequentItems}.
     */
    private List<Set<Integer>> generateCandidates(int[] items, int k) {
        List<Set<Integer>> result = new ArrayList<Set<Integer>>();
        for (Set<Integer> candidate : generateSubSets(filterItems(items), k)) {
            // If any (k-1)-subset is not frequent, the candidate cannot be frequent.
            boolean allSubsetsFrequent = true;
            for (Set<Integer> subset : generateSubSets(toIntArray(candidate), k - 1)) {
                if (!frequentItems.contains(subset)) {
                    allSubsetsFrequent = false;
                    break;
                }
            }
            if (allSubsetsFrequent) {
                result.add(candidate);
            }
        }
        return result;
    }

    /** Drops the items that are not part of any frequent itemset of the previous pass. */
    private int[] filterItems(int[] array) {
        List<Integer> kept = new ArrayList<Integer>();
        for (int item : array) {
            if (frequentSingleItems.contains(item)) {
                kept.add(item);
            }
        }
        return toIntArray(kept);
    }

    /** Copies a collection of boxed integers into a primitive array, in iteration order. */
    private int[] toIntArray(Iterable<Integer> values) {
        List<Integer> buffer = new ArrayList<Integer>();
        for (Integer value : values) {
            buffer.add(value);
        }
        int[] out = new int[buffer.size()];
        for (int i = 0; i < out.length; i++) {
            out[i] = buffer.get(i);
        }
        return out;
    }

    /**
     * Enumerates all k-element combinations of {@code array} using a bit-mask
     * that starts with k leading ones and is advanced by shifting.
     *
     * @return the combinations as sorted sets; empty when {@code array} has
     *         fewer than k elements or when k is 0
     */
    private List<Set<Integer>> generateSubSets(int[] array, int k) {
        List<Set<Integer>> list = new ArrayList<Set<Integer>>();
        if (array.length < k) {
            return list;
        }
        // Selection mask: the first k positions start selected.
        byte[] bits = new byte[array.length];
        for (int i = 0; i < bits.length; i++) {
            bits[i] = i < k ? (byte) 1 : (byte) 0;
        }
        boolean find;
        do {
            Set<Integer> set = getCombination(array, bits);
            if (set.size() != 0) {
                list.add(set);
            }
            // Advance: find the first "10" pair and turn it into "01".
            find = false;
            for (int i = 0; i < array.length - 1; i++) {
                if (bits[i] == 1 && bits[i + 1] == 0) {
                    find = true;
                    bits[i] = 0;
                    bits[i + 1] = 1;
                    if (bits[0] == 0) {
                        // Pack the ones left of position i back to the front.
                        for (int p = 0, q = 0; p < i; p++) {
                            if (bits[p] == 1) {
                                byte temp = bits[p];
                                bits[p] = bits[q];
                                bits[q] = temp;
                                q++;
                            }
                        }
                    }
                    break;
                }
            }
        } while (find);
        return list;
    }

    /** Collects the elements of {@code array} whose mask bit is set into a sorted set. */
    private Set<Integer> getCombination(int[] array, byte[] bits) {
        Set<Integer> set = new TreeSet<Integer>();
        for (int i = 0; i < bits.length; i++) {
            if (bits[i] == (byte) 1) {
                set.add(array[i]);
            }
        }
        return set;
    }
}
import java.util.Map;
import java.util.Set;

/**
 * Command-line driver: mines {@code src/test.txt} with a minimum support of
 * 0.02, running A-Priori passes until a pass yields no frequent itemsets, and
 * reports the elapsed wall-clock time.
 */
public class Main {

    public static void main(String[] args) {
        System.out.println("program starts…");
        final long begin = System.currentTimeMillis();

        final String inputFile = "src/test.txt";
        final double minSupport = 0.02;
        final Apriori miner = new Apriori(inputFile, minSupport);

        try {
            // Pass 1 handles single items separately from the general pass.
            System.out.println("pass 1");
            Map<Set<Integer>, Integer> frequent = miner.findF1Item();
            miner.printFrequentItems(frequent, 1);

            // Pass 2 always runs; stop after the first pass that finds nothing.
            for (int pass = 2; ; pass++) {
                System.out.println("pass " + pass);
                frequent = miner.generateNextPass(pass);
                miner.printFrequentItems(frequent, pass);
                if (frequent.size() == 0) {
                    break;
                }
            }
            miner.closeOutputWriter();
        } catch (Exception e) {
            e.printStackTrace();
        }

        final long elapsed = System.currentTimeMillis() - begin;
        System.out.println("execution time:" + elapsed + "ms");
    }
}
输入文件:test.txt
1 2 5
2 4
2 3
1 2 4
1 3
2 3
1 3
1 2 3 5
1 2 3
输出文件:output.txt
[1 ], count:6
[2 ], count:7
[3 ], count:6
[4 ], count:2
[5 ], count:2
[1 2 ], count:4
[1 3 ], count:4
[1 4 ], count:1
[2 3 ], count:4
[2 4 ], count:2
[1 5 ], count:2
[2 5 ], count:2
[3 5 ], count:1
[1 2 3 ], count:2
[1 2 4 ], count:1
[1 2 5 ], count:2
[1 3 5 ], count:1
[2 3 5 ], count:1
[1 2 3 5 ], count:1