Frequent Itemsets mining(频繁项集挖掘)最常见的形式是:给定market-basket形式的数据(每一行相当于一个购物篮,包含多个商品item),找出出现频率(支持度)不低于某个阈值(最小支持度)的所有item集合。A-Priori算法是Frequent Itemsets mining里最基本的算法。
一、A-Priori的基本思路
第一遍(pass 1),扫描文件,统计单项(single item)的出现次数(使用1个map进行统计,key就是item,value就是出现次数)。最后,过滤掉小于最小支持度的,得到频繁单项集。
第二遍(pass 2),扫描文件,对于每一行,对任意两个item组合得到pair item。如果pair item的2个单项都在频繁单项集里,则统计这个pair item的出现次数;否则略过。最后得到所有pair item的出现频率,过滤掉小于最小支持度的,得到频繁2项集。
第N遍(pass N),如果频繁N-1项集不为空,则说明mining还没有完成,需要进行第N次扫描。对每一行,枚举其中任意N个item的组合,如果这N项的所有N-1项子集都在频繁N-1项集里,则统计其出现次数,否则略过。最后,过滤掉小于最小支持度的,得到频繁N项集。
直到频繁项集为空才停止循环。
二、java实现
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
/**
 * A-Priori frequent itemset miner for market-basket data stored in a text file.
 *
 * <p>Each non-blank line of the input file is one transaction: integer item ids separated by
 * whitespace. Call {@link #findF1Item()} first (it also counts the transactions), then
 * {@link #generateNextPass(int)} with k = 2, 3, ... until it returns an empty map. Every pass
 * re-reads the input file. Instances hold mutable per-pass state and do no synchronization.
 */
public class Apriori {
    /** Path of the transaction file; re-read on every pass. */
    private String inputFile;
    /** Minimum support as a fraction of the number of transactions. */
    private double minSupport;
    /** Writer for output.txt, opened lazily by printFrequentItems. */
    private BufferedWriter bw;
    /** Frequent itemsets of the most recent pass (pass k-1 while pass k runs). */
    private HashSet<Set<Integer>> frequentItems = new HashSet<Set<Integer>>();
    /** Every item that appears in at least one itemset of frequentItems. */
    private Set<Integer> frequentSingleItems = new HashSet<Integer>();
    /** Number of non-blank transactions (lines) counted by findF1Item. */
    private int totalCount;

    /**
     * @param inputFile path of the transaction file
     * @param minSupport minimum support as a fraction of the transaction count; an itemset is
     *     frequent when its count is at least {@code minSupport * totalCount}
     */
    public Apriori(String inputFile, double minSupport) {
        this.inputFile = inputFile;
        this.minSupport = minSupport;
    }

    /**
     * Pass 1: counts single items and keeps the frequent ones.
     *
     * <p>Also records the number of non-blank transactions used by later passes, and resets the
     * per-pass state so pass 2 prunes against exactly these frequent items.
     *
     * @return frequent 1-itemsets mapped to their transaction counts, ordered by item id
     * @throws IOException if the input file cannot be read
     * @throws NumberFormatException if a token in the file is not an integer
     */
    public Map<Set<Integer>, Integer> findF1Item() throws IOException {
        Map<Integer, Integer> counts = new HashMap<Integer, Integer>();
        int numberOfLines = 0;
        BufferedReader reader = new BufferedReader(new FileReader(inputFile));
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                Set<Integer> items = parseTransaction(line);
                if (items.isEmpty()) {
                    continue; // blank line: not a transaction
                }
                numberOfLines++;
                for (Integer item : items) {
                    Integer old = counts.get(item);
                    counts.put(item, old == null ? 1 : old + 1);
                }
            }
        } finally {
            reader.close();
        }
        totalCount = numberOfLines;

        frequentItems = new HashSet<Set<Integer>>();
        frequentSingleItems = new HashSet<Integer>();
        Map<Set<Integer>, Integer> result = new LinkedHashMap<Set<Integer>, Integer>();
        // TreeSet iterates item ids in ascending order, so the result map is sorted by item.
        for (Integer item : new TreeSet<Integer>(counts.keySet())) {
            int count = counts.get(item);
            if (count >= minSupport * totalCount) {
                Set<Integer> f1Set = new TreeSet<Integer>();
                f1Set.add(item);
                result.put(f1Set, count);
                frequentItems.add(f1Set);
                frequentSingleItems.add(item);
            }
        }
        return result;
    }

    /**
     * Pass k: counts candidate k-itemsets and keeps the frequent ones.
     *
     * <p>A k-itemset from a transaction is counted only if all of its (k-1)-subsets were
     * frequent in the previous pass. The surviving itemsets replace the per-pass state used to
     * prune pass k+1.
     *
     * @param k size of the itemsets to mine in this pass
     * @return frequent k-itemsets mapped to their transaction counts (empty when mining is done)
     * @throws Exception if the input file cannot be read or contains a non-integer token
     */
    public Map<Set<Integer>, Integer> generateNextPass(int k) throws Exception {
        Map<Set<Integer>, Integer> counts = new HashMap<Set<Integer>, Integer>();
        BufferedReader reader = new BufferedReader(new FileReader(inputFile));
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                for (Set<Integer> candidate : collectCandidates(parseTransaction(line), k)) {
                    Integer old = counts.get(candidate);
                    counts.put(candidate, old == null ? 1 : old + 1);
                }
            }
        } finally {
            reader.close();
        }

        frequentItems = new HashSet<Set<Integer>>();
        frequentSingleItems = new HashSet<Integer>();
        Iterator<Set<Integer>> it = counts.keySet().iterator();
        while (it.hasNext()) {
            Set<Integer> itemSet = it.next();
            if (counts.get(itemSet) < totalCount * minSupport) {
                it.remove();
            } else {
                frequentItems.add(itemSet);
                frequentSingleItems.addAll(itemSet);
            }
        }
        return counts;
    }

    /**
     * Appends each itemset and its count to output.txt as {@code "[a b ], count:n"}, one per line,
     * then flushes. The file is opened (and truncated) on the first call.
     *
     * @param itemSets itemsets to print, mapped to their counts
     * @param i pass number; currently not used in the output
     */
    public void printFrequentItems(Map<Set<Integer>, Integer> itemSets, int i)
            throws FileNotFoundException, IOException {
        if (bw == null) {
            bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("output.txt")));
        }
        StringBuilder sb = new StringBuilder();
        for (Map.Entry<Set<Integer>, Integer> entry : itemSets.entrySet()) {
            sb.append("[");
            for (Integer item : entry.getKey()) {
                sb.append(item).append(' ');
            }
            sb.append("], count:").append(entry.getValue()).append("\n");
        }
        bw.write(sb.toString());
        bw.flush();
    }

    /** Closes the output writer if printFrequentItems has opened it; does nothing otherwise. */
    public void closeOutputWriter() throws IOException {
        if (bw != null) {
            bw.close();
        }
    }

    /**
     * Parses one input line into its distinct item ids in ascending order.
     *
     * <p>Tokens are separated by any run of whitespace, and a blank line yields an empty set.
     * Duplicates are collapsed because support counts transactions, not occurrences.
     *
     * @throws NumberFormatException if a token is not an integer
     */
    private static Set<Integer> parseTransaction(String line) {
        Set<Integer> items = new TreeSet<Integer>();
        String trimmed = line.trim();
        if (trimmed.length() == 0) {
            return items;
        }
        for (String token : trimmed.split("\\s+")) {
            items.add(Integer.parseInt(token));
        }
        return items;
    }

    /**
     * Returns the size-k itemsets of one transaction that are candidates for pass k.
     *
     * <p>Only items still in frequentSingleItems are combined, and a combination is kept only
     * if every one of its (k-1)-subsets is in frequentItems (the A-Priori pruning rule).
     */
    private List<Set<Integer>> collectCandidates(Set<Integer> transaction, int k) {
        List<Set<Integer>> result = new ArrayList<Set<Integer>>();
        for (Set<Integer> candidate : generateSubSets(filterItems(transaction), k)) {
            if (allSubsetsFrequent(candidate, k - 1)) {
                result.add(candidate);
            }
        }
        return result;
    }

    /** True if every {@code size}-element subset of {@code candidate} is in frequentItems. */
    private boolean allSubsetsFrequent(Set<Integer> candidate, int size) {
        for (Set<Integer> subset : generateSubSets(toIntArray(candidate), size)) {
            if (!frequentItems.contains(subset)) {
                return false;
            }
        }
        return true;
    }

    /** Drops items that did not occur in any frequent itemset of the previous pass. */
    private int[] filterItems(Set<Integer> transaction) {
        Set<Integer> kept = new TreeSet<Integer>();
        for (Integer item : transaction) {
            if (frequentSingleItems.contains(item)) {
                kept.add(item);
            }
        }
        return toIntArray(kept);
    }

    /** Copies the values of {@code values}, in iteration order, into an int array. */
    private static int[] toIntArray(Set<Integer> values) {
        int[] array = new int[values.size()];
        int i = 0;
        for (Integer value : values) {
            array[i++] = value;
        }
        return array;
    }

    /**
     * Enumerates all k-element combinations of {@code array} as sorted sets.
     *
     * <p>Uses the "10 -> 01" method. Start with the first k flags set. Repeatedly turn the
     * leftmost "10" into "01" and pack the 1s to its left back to the front, until no "10"
     * remains. Returns an empty list when {@code array.length < k} or {@code k <= 0}.
     */
    private List<Set<Integer>> generateSubSets(int[] array, int k) {
        List<Set<Integer>> list = new ArrayList<Set<Integer>>();
        if (array.length < k) {
            return list;
        }
        byte[] bits = new byte[array.length];
        for (int i = 0; i < bits.length; i++) {
            bits[i] = i < k ? (byte) 1 : (byte) 0;
        }
        boolean found;
        do {
            Set<Integer> set = getCombination(array, bits);
            if (!set.isEmpty()) {
                list.add(set);
            }
            found = false;
            for (int i = 0; i < array.length - 1; i++) {
                if (bits[i] == 1 && bits[i + 1] == 0) {
                    found = true;
                    bits[i] = 0;
                    bits[i + 1] = 1;
                    // Move the 1s left of position i to the front to get the next combination.
                    if (bits[0] == 0) {
                        for (int p = 0, q = 0; p < i; p++) {
                            if (bits[p] == 1) {
                                byte temp = bits[p];
                                bits[p] = bits[q];
                                bits[q] = temp;
                                q++;
                            }
                        }
                    }
                    break;
                }
            }
        } while (found);
        return list;
    }

    /** Returns the elements of {@code array} whose flag in {@code bits} is 1, as a sorted set. */
    private Set<Integer> getCombination(int[] array, byte[] bits) {
        Set<Integer> set = new TreeSet<Integer>();
        for (int i = 0; i < bits.length; i++) {
            if (bits[i] == (byte) 1) {
                set.add(array[i]);
            }
        }
        return set;
    }
}
import java.util.Map;
import java.util.Set;
public class Main {
    /**
     * Runs A-Priori on src/test.txt with a minimum support of 0.02.
     *
     * <p>Each pass's frequent itemsets are handed to Apriori.printFrequentItems. Passes
     * continue until one yields no frequent itemsets, and the elapsed wall-clock time is
     * printed at the end. Any exception is reported via its stack trace; in that case the
     * output writer is not closed.
     */
    public static void main(String[] args) {
        System.out.println("program starts…");
        long begin = System.currentTimeMillis();
        Apriori miner = new Apriori("src/test.txt", 0.02);
        try {
            System.out.println("pass 1");
            Map<Set<Integer>, Integer> frequent = miner.findF1Item();
            miner.printFrequentItems(frequent, 1);
            // Mine ever larger itemsets; the empty pass is still printed before stopping.
            for (int pass = 2; ; pass++) {
                System.out.println("pass " + pass);
                frequent = miner.generateNextPass(pass);
                miner.printFrequentItems(frequent, pass);
                if (frequent.isEmpty()) {
                    break;
                }
            }
            miner.closeOutputWriter();
        } catch (Exception e) {
            e.printStackTrace();
        }
        long elapsed = System.currentTimeMillis() - begin;
        System.out.println("execution time:" + elapsed + "ms");
    }
}
输入文件:test.txt
1 2 5
2 4
2 3
1 2 4
1 3
2 3
1 3
1 2 3 5
1 2 3
输出文件:output.txt
[1 ], count:6
[2 ], count:7
[3 ], count:6
[4 ], count:2
[5 ], count:2
[1 2 ], count:4
[1 3 ], count:4
[1 4 ], count:1
[2 3 ], count:4
[2 4 ], count:2
[1 5 ], count:2
[2 5 ], count:2
[3 5 ], count:1
[1 2 3 ], count:2
[1 2 4 ], count:1
[1 2 5 ], count:2
[1 3 5 ], count:1
[2 3 5 ], count:1
[1 2 3 5 ], count:1