2008-05-10

Erlangでベクトルの集合をクラスタリング(k-means)

Erlang本を買ったので、k-means法によるクラスタリングをサンプル的に作ってみた。
Erlang簡潔でいいですね。

使い方
データファイルをタブ区切りで用意する。
1カラム目はデータの識別用ID文字列、2カラム目以降にベクトルの数値データを記入。改行コードは¥nで。

ayu@~/work% /usr/local/erlang_R12B_2/bin/erl
Erlang (BEAM) emulator version 5.6.2 [source] [smp:2] [async-threads:0] [kernel-poll:false]
Eshell V5.6.2 (abort with ^G)
1> c(kmeans).
{ok,kmeans}
2> Src = kmeans:read_line("sampledata2.txt").
[{"1",[1,9,2,4,4,163.5,54,3,1,4,4,2.5,1,1,1,4,3,3,3,2,5]},
{"2",[2,7,1,5,5,154.2,42,2,2,4,3,24.5,1,1,4,3,4,3,5,2,4]},
{"3",[2,8,2,5,3,153.8,44.8,3.5,1,2,5,24,1,1,5,3,2,5,4,3,2]},
.....
3> {L, R} = kmeans:splitvectorset(Src).
{[{"210",[2,4,2,3,3,156,42,3,2,3,4,2,2,2,3,2,3,5,4,1,3]},
{"205",[1,6,1,3,5,170,54,2.5,2,3,3,1.5,1,1,4,3,4,1,2,3,5]},
{"203",[1,2,1,2,3,163,53,2,2,4,4,24,2,1,3,2,2,1,3,1,3]},
....
4>

以上で、LとRに二分割された集合が格納される。

ソース
-module(kmeans).
-import(lists, [map/2, zip/2, sum/1]).
-import(math, [sqrt/1]).
-export([splitvectorset/1, read_line/1]).

%ベクトルのノルム
norm(X)->
sqrt(sum(map(fun(Y)->Y*Y end, X))).

%ベクトル集合の重心
median([],_)->
[];
median(X, L)->
First = lists:nth(1,X),
if
First == [] -> [];
true -> [sum(map(fun erlang:hd/1, X))/L | median(map(fun erlang:tl/1, X), L)]
end.
median(X)->
median(X, length(X)).

%ベクトルの差
sub([], _)->
[];
sub(X, Y) ->
[hd(X)-hd(Y)|sub(tl(X), tl(Y))].

%二つのベクトル間の距離
distance(X, Y) ->
norm(sub(X,Y)).

%与えられた二点X,YのどちらにベクトルZは近いか?
% Xに近ければ'Right', Yに近ければ'Left'を返す。
which(_,_,[])->
ok;
which(X, Y, Z) ->
DX = distance(X,Z),
DY = distance(Y,Z),
if
DX > DY -> 'Left';
true -> 'Right'
end.

%[{'Right/Left', ベクトル}, ...]というリストを'Right'/'Left'に従って二つの
%リストに分けて返す。
splitvector([], R, L)->
{R, L};
splitvector([{D,Z}|T],R,L) ->
case D of
'Right' -> splitvector(T, [Z|R], L);
'Left' -> splitvector(T, R, [Z|L])
end.

%集合ZSの2元からなる部分集合列を返す
pairs([])->
[];
pairs([Z|T]) ->
lists:append(map(fun(X)->{Z,X} end, T), pairs(T)).
%ベクトル集合の距離&ベクトルペアの列を作る
pairdistance([])->
[];
pairdistance(ZS) ->
map(fun({X,Y})->{distance(X,Y), X, Y} end, pairs(ZS)).

%ベクトル集合の最大距離を持つ2点を選ぶ
maxdistance(ZS)->
Pairs = pairdistance(ZS),
{_, MX, MY} = maxdistance(Pairs, hd(Pairs)),
{MX, MY}.
maxdistance([], M)-> M;
maxdistance([Z|T], M)->
{D1,_,_} = M,
{D2,_,_} = Z,
if
D2 > D1 -> maxdistance(T, Z);
true -> maxdistance(T, M)
end.


%二点X,Yのどっちに近いかで、ベクトル集合ZSを二分割する。
splitvectorset(X,Y,ZS, Thre)->
{R,L} = splitvector(zip(map(fun({_,W})->which(X,Y,W) end, ZS), ZS), [], []),
IdRmv = fun ({_, V})->V end,
CR = median(map(IdRmv, R)),
CL = median(map(IdRmv, L)),
MinDist = lists:min([distance(X,CR)+distance(Y, CL), distance(Y, CR)+distance(X, CL)]),
if
MinDist > Thre ->
splitvectorset(CR, CL, ZS, Thre);
true -> {R, L}
end.

%公開関数その1
%ベクトル集合ZSを二分割する。
%集合ZSは、もう一つの公開関数read_lineで読み込んだデータです。
splitvectorset(ZS)->
Threshold = 0.1,
case length(ZS) of
0 -> [[],[]];
1 -> [ZS|[[]]];
2 -> [[lists:nth(1, ZS)], [lists:nth(2, ZS)]];
_ ->
{X, Y} = maxdistance(map(fun ({_, X})->X end, ZS)),
splitvectorset(X, Y, ZS, Threshold)
end.

%分析するデータファイルを読んで{ID, ベクトル}のリストとして返します。
% データファイル仕様
% データファイルはタブ区切り、先頭カラムはレコードID文字列、次カラム以降が数値データ
% 改行コードは\n
read_line(File) ->
{ok, IoDevice} = file:open(File, read),
ANS = read_line(IoDevice, 1, []),
file:close(IoDevice),
ANS.

read_line(IoDevice, LineNumber, Buf) ->
case io:get_line(IoDevice, "") of
eof ->
Buf;
Line ->
Data = string:tokens(lists:delete(10, Line), "\t"), % 10 = "\n"
if
length(Data) == 0 -> ok;
true -> Vec = {hd(Data), map(fun(X)-> parse_num(X) end, tl(Data))},
read_line(IoDevice, LineNumber + 1, lists:append(Buf, [Vec]))
end
end.

%数値を表す文字列をinteger又はfloatに変換します。
parse_num(X)->
case string:str(X, ".") of
0 -> list_to_integer(X);
_ -> list_to_float(X)
end.