2013
12.12

Frequent Itemsets mining(频繁项集挖掘)最常见的形式是:给定market-basket形式的数据(每一行相当于一个购物篮,包含多个商品item),找出支持度不小于某个阈值(最小支持度)的所有item集合。一个item集合的支持度,是包含该集合的购物篮占全部购物篮的比例。A-Priori算法是Frequent Itemsets mining里最基本的算法。

一、A-Priori的基本思路

第一遍(pass 1),扫描文件,统计单项(single item)的出现次数(使用1个map进行统计,key就是item,value就是出现次数)。最后,过滤掉出现次数小于“最小支持度×总行数”的item,得到频繁单项集。

第二遍(pass 2),扫描文件,对于每一行,对任意两个item组合得到pair item。如果pair item的2个单项都在频繁单项集里,则统计这个pair item的出现次数;否则略过。最后得到所有pair item的出现频率,过滤掉小于最小支持度的,得到频繁2项集。

第N遍,如果发现频繁N-1项集不为空,则说明mining还没有完成,需要进行第N次扫描。从每行中取得任意N个item,如果这N项的所有N-1项子集都在频繁N-1项集里,则统计其出现次数,否则略过(依据A-Priori原理:一个项集如果是频繁的,它的所有子集也必然是频繁的;只要有一个N-1项子集不频繁,这N项就不可能频繁)。最后,过滤掉出现次数小于“最小支持度×总行数”的项集,得到频繁N项集。

循环一直进行,直到某一遍得到的频繁N项集为空才停止。

二、java实现

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

/**
 * A-Priori frequent-itemset miner over a market-basket text file.
 *
 * <p>Each line of the input file is one transaction (basket): integer item ids separated by
 * whitespace. Blank lines are skipped and do not count as transactions; an item repeated inside
 * one line is counted once for that line. Pass 1 is {@link #findF1Item()}; pass k (k >= 2) is
 * {@link #generateNextPass(int)}, which must be called with k = 2, 3, ... in order, because each
 * pass prunes its candidates with the frequent itemsets kept from the previous pass.
 *
 * <p>An itemset is frequent when its count is at least {@code minSupport * totalCount}, where
 * {@code totalCount} is the number of non-blank lines. Instances keep mutable per-pass state in
 * unsynchronized fields.
 */
public class Apriori {

	/** Path of the transaction file; it is re-read on every pass. */
	private String inputFile;

	/** Minimum support as a fraction of the number of transactions (e.g. 0.02 means 2%). */
	private double minSupport;

	/** Writer for "output.txt", opened lazily by printFrequentItems. */
	private BufferedWriter bw;

	/** Frequent itemsets of the most recent pass; next-pass candidates are pruned against it. */
	private HashSet<Set<Integer>> frequentItems = new HashSet<Set<Integer>>();

	/** Items occurring in at least one itemset of frequentItems; other items are dropped from lines. */
	private Set<Integer> frequentSingleItems = new HashSet<Integer>();

	/** Number of non-blank transactions, computed by findF1Item. */
	private int totalCount;

	public Apriori(String inputFile, double minSupport) {
		this.inputFile = inputFile;
		this.minSupport = minSupport;
	}

	/**
	 * Pass 1: counts every single item and keeps the frequent ones.
	 *
	 * <p>Also records the number of transactions and resets the pruning state used by pass 2.
	 *
	 * @return frequent 1-itemsets mapped to their counts, ordered by ascending item id
	 * @throws IOException if the input file cannot be read
	 * @throws NumberFormatException if a token in the file is not an integer
	 */
	public Map<Set<Integer>, Integer> findF1Item() throws IOException {
		Map<Integer, Integer> counts = new HashMap<Integer, Integer>();
		int transactions = 0;
		BufferedReader reader = new BufferedReader(new FileReader(inputFile));
		try {
			String line;
			while ((line = reader.readLine()) != null) {
				Set<Integer> items = parseTransaction(line);
				if (items.isEmpty()) {
					continue; // blank line: not a transaction
				}
				transactions++;
				for (Integer item : items) {
					increment(counts, item);
				}
			}
		} finally {
			reader.close();
		}
		totalCount = transactions;

		// Start from a clean state so that calling this method twice does not accumulate itemsets.
		frequentItems = new HashSet<Set<Integer>>();
		frequentSingleItems = new HashSet<Integer>();
		Map<Set<Integer>, Integer> result = new LinkedHashMap<Set<Integer>, Integer>();
		double threshold = minSupport * totalCount;
		// TreeSet iterates item ids in ascending order, so the LinkedHashMap result is sorted.
		for (Integer item : new TreeSet<Integer>(counts.keySet())) {
			int count = counts.get(item);
			if (count >= threshold) {
				Set<Integer> singleton = new TreeSet<Integer>();
				singleton.add(item);
				result.put(singleton, count);
				frequentItems.add(singleton);
				frequentSingleItems.add(item);
			}
		}
		return result;
	}

	/**
	 * Pass k (k >= 2): counts candidate k-itemsets and keeps the frequent ones.
	 *
	 * <p>A k-subset of a transaction is a candidate only if all of its items survived the previous
	 * pass and every one of its (k-1)-subsets was frequent in the previous pass. Must be called
	 * after findF1Item, with k increasing by one each time. Replaces the pruning state with the
	 * itemsets found here.
	 *
	 * @param k size of the itemsets to count
	 * @return frequent k-itemsets mapped to their counts; empty when mining is finished
	 * @throws Exception if the input file cannot be read or contains a non-integer token
	 */
	public Map<Set<Integer>, Integer> generateNextPass(int k) throws Exception {
		Map<Set<Integer>, Integer> counts = new HashMap<Set<Integer>, Integer>();
		BufferedReader reader = new BufferedReader(new FileReader(inputFile));
		try {
			String line;
			while ((line = reader.readLine()) != null) {
				for (Set<Integer> candidate : generateCandidates(parseTransaction(line), k)) {
					increment(counts, candidate);
				}
			}
		} finally {
			reader.close();
		}

		frequentItems = new HashSet<Set<Integer>>();
		frequentSingleItems = new HashSet<Integer>();
		double threshold = totalCount * minSupport;
		Iterator<Map.Entry<Set<Integer>, Integer>> it = counts.entrySet().iterator();
		while (it.hasNext()) {
			Map.Entry<Set<Integer>, Integer> entry = it.next();
			if (entry.getValue() < threshold) {
				it.remove();
			} else {
				frequentItems.add(entry.getKey());
				frequentSingleItems.addAll(entry.getKey());
			}
		}
		return counts;
	}

	/**
	 * Writes one line per itemset, formatted as "[a b ], count:N", to output.txt in the working
	 * directory. The file is created (truncated) on the first call and flushed after each call.
	 *
	 * @param itemSets itemsets and their counts
	 * @param i pass number; currently unused
	 */
	public void printFrequentItems(Map<Set<Integer>, Integer> itemSets, int i) throws FileNotFoundException, IOException {
		if (bw == null) {
			bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("output.txt")));
		}
		StringBuilder sb = new StringBuilder();
		for (Map.Entry<Set<Integer>, Integer> entry : itemSets.entrySet()) {
			sb.append("[");
			for (Integer item : entry.getKey()) {
				sb.append(item).append(' ');
			}
			sb.append("], count:");
			sb.append(entry.getValue());
			sb.append("\n");
		}
		bw.write(sb.toString());
		bw.flush();
	}

	/** Closes the output writer if one was opened; does nothing otherwise. */
	public void closeOutputWriter() throws IOException {
		if (bw != null) {
			bw.close();
		}
	}

	/** Parses one line into its distinct item ids (sorted); a blank line yields an empty set. */
	private static Set<Integer> parseTransaction(String line) {
		Set<Integer> items = new TreeSet<Integer>();
		String trimmed = line.trim();
		if (trimmed.isEmpty()) {
			return items;
		}
		for (String token : trimmed.split("\\s+")) {
			items.add(Integer.valueOf(token));
		}
		return items;
	}

	/** Adds one to the count stored for key, treating a missing key as zero. */
	private static <K> void increment(Map<K, Integer> counts, K key) {
		Integer current = counts.get(key);
		counts.put(key, current == null ? 1 : current + 1);
	}

	/** Returns the k-subsets of one transaction that pass the A-Priori pruning rule. */
	private List<Set<Integer>> generateCandidates(Set<Integer> transaction, int k) {
		List<Set<Integer>> result = new ArrayList<Set<Integer>>();
		for (Set<Integer> candidate : generateSubSets(filterItems(transaction), k)) {
			if (allSubsetsFrequent(candidate, k - 1)) {
				result.add(candidate);
			}
		}
		return result;
	}

	/** True if every subset of the given size is in frequentItems (vacuously true for size <= 0). */
	private boolean allSubsetsFrequent(Set<Integer> candidate, int size) {
		for (Set<Integer> subset : generateSubSets(toIntArray(candidate), size)) {
			if (!frequentItems.contains(subset)) {
				return false;
			}
		}
		return true;
	}

	/** Keeps only the items that appeared in a frequent itemset of the previous pass. */
	private int[] filterItems(Set<Integer> items) {
		List<Integer> kept = new ArrayList<Integer>();
		for (Integer item : items) {
			if (frequentSingleItems.contains(item)) {
				kept.add(item);
			}
		}
		return toIntArray(kept);
	}

	/** Copies the elements of a collection of Integers into a new int array, in iteration order. */
	private static int[] toIntArray(Iterable<Integer> values) {
		List<Integer> list = new ArrayList<Integer>();
		for (Integer v : values) {
			list.add(v);
		}
		int[] array = new int[list.size()];
		for (int i = 0; i < array.length; i++) {
			array[i] = list.get(i);
		}
		return array;
	}

	/**
	 * Enumerates all k-element combinations of {@code array} with the "10 -> 01" bit-shift method.
	 *
	 * <p>A 0/1 mask starts with its first k positions set. Each step emits the selected elements,
	 * then finds the leftmost "10", turns it into "01", and packs every 1 to its left back to the
	 * start of the mask. Enumeration stops when no "10" remains.
	 *
	 * @return one TreeSet per combination; empty if array.length < k or k <= 0
	 */
	private List<Set<Integer>> generateSubSets(int[] array, int k) {
		List<Set<Integer>> list = new ArrayList<Set<Integer>>();
		if (array.length < k) {
			return list;
		}
		// Initial mask: the first k positions selected.
		byte[] bits = new byte[array.length];
		for (int i = 0; i < bits.length; i++) {
			bits[i] = i < k ? (byte) 1 : (byte) 0;
		}
		boolean find;
		do {
			Set<Integer> set = getCombination(array, bits);
			if (!set.isEmpty()) {
				list.add(set);
			}
			// Find the leftmost "10" and turn it into "01".
			find = false;
			for (int i = 0; i < array.length - 1; i++) {
				if (bits[i] == 1 && bits[i + 1] == 0) {
					find = true;
					bits[i] = 0;
					bits[i + 1] = 1;
					// The prefix before i has the form 0*1*; if it starts with 0, slide its 1s to the front.
					if (bits[0] == 0) {
						for (int p = 0, q = 0; p < i; p++) {
							if (bits[p] == 1) {
								byte temp = bits[p];
								bits[p] = bits[q];
								bits[q] = temp;
								q++;
							}
						}
					}
					break;
				}
			}
		} while (find);
		return list;
	}

	/** Returns the elements of array whose mask bit is 1, as a sorted set. */
	private Set<Integer> getCombination(int[] array, byte[] bits) {
		Set<Integer> set = new TreeSet<Integer>();
		for (int i = 0; i < bits.length; i++) {
			if (bits[i] == (byte) 1) {
				set.add(array[i]);
			}
		}
		return set;
	}

}



import java.util.Map;
import java.util.Set;

public class Main {

	/**
	 * Runs A-Priori pass by pass until a pass finds no frequent itemsets, writing each pass's
	 * frequent itemsets to output.txt and printing the elapsed time.
	 *
	 * @param args optional: args[0] = input file (default "src/test.txt"),
	 *             args[1] = minimum support as a fraction of transactions (default 0.02)
	 */
	public static void main(String[] args) {
		System.out.println("program starts…");
		long startTime = System.currentTimeMillis();
		String inputFile = args.length > 0 ? args[0] : "src/test.txt";
		double minSupport = args.length > 1 ? Double.parseDouble(args[1]) : 0.02;
		Apriori apriori = new Apriori(inputFile, minSupport);
		try {
			System.out.println("pass 1");
			Map<Set<Integer>, Integer> result = apriori.findF1Item();
			apriori.printFrequentItems(result, 1);
			int pass = 2;
			do {
				System.out.println("pass " + pass);
				result = apriori.generateNextPass(pass);
				apriori.printFrequentItems(result, pass);
				pass++;
			} while (!result.isEmpty());
		} catch (Exception e) {
			e.printStackTrace();
		} finally {
			// Close the output even when a pass failed; a close failure must not hide the timing line.
			try {
				apriori.closeOutputWriter();
			} catch (Exception e) {
				e.printStackTrace();
			}
		}
		long endTime = System.currentTimeMillis();
		System.out.println("execution time:" + (endTime - startTime) + "ms");
	}
}

输入文件:test.txt

1 2 5
2 4
2 3
1 2 4
1 3
2 3
1 3
1 2 3 5
1 2 3

输出文件:output.txt

[1 ], count:6
[2 ], count:7
[3 ], count:6
[4 ], count:2
[5 ], count:2
[1 2 ], count:4
[1 3 ], count:4
[1 4 ], count:1
[2 3 ], count:4
[2 4 ], count:2
[1 5 ], count:2
[2 5 ], count:2
[3 5 ], count:1
[1 2 3 ], count:2
[1 2 4 ], count:1
[1 2 5 ], count:2
[1 3 5 ], count:1
[2 3 5 ], count:1
[1 2 3 5 ], count:1

回复功能关闭


Hit Counter by http://yizhantech.com/