判別する〜集団の平均点の垂直2等分線
[b][size=150]<相関比>[/size][/b][br]2つの集団がまざって分布しているとき、どこに境界線をひくのかという問題が判別問題です。[br]分散は平均との差の2乗の平均ですが、個数で割る前の2乗和を変動といいます。[br]平均からの距離の2乗を集団で合計したものです。[br]集団Pと集団Nのそれぞれの平均をm[sub]p[/sub], m[sub]n[/sub],とします。全体の平均をmとします。[br]全体変動(全変動)はT=Σ(x-m)[sup]2[/sup]、[br]集団内の変動の和(群内変動)はI=Σ(x[sub]p[/sub]-m[sub]p[/sub])[sup]2[/sup]+Σ(x[sub]n[/sub]ーm[sub]n[/sub])[sup]2[/sup][br]集団の平均どうしの変動和(群間変動)はJ=Σ(m − m[sub]p[/sub])[sup]2[/sup]+Σ(m - m[sub]n[/sub])[sup]2[/sup][br][color=#0000ff][b]T=I+J(全変動=群内変動+群間変動)となります。[br][/b][/color][br]TにしめるJの値が大きいと、それぞれの集団のかたまりが強いことになります。[br]2集団の分離がよりはっきりしています。[br]相関比=群間変動の比率=J/Tまたは、J/Iの値が判別の強さを表しているともいえます。[br][br]<課題1>[br]男子10人女子10人の身長x、体重yのデータから、男女を判別するx1とx2の関係式をさがす。[br][table][br][tr][td]番号[/td][td]1[/td][td]2[/td][td]3[/td][td]4[/td][td]5[/td][td]6[/td][td]7[/td][td]8[/td][td]9[/td][td]10[/td][/tr][br][tr][td]x[/td][td]151.1[/td][td]155.9[/td][td]159.4[/td][td]154.6[/td][td]162.9[/td][td]158.3[/td][td]171.7[/td][td]160.8[/td][td]153.4[/td][td]161.2[/td][/tr][br][tr][td]y[/td][td]43.7[/td][td]46.2[/td][td]49.5[/td][td]56.3[/td][td]50.9[/td][td]63.5[/td][td]59.8[/td][td]51.7[/td][td]58.3[/td][td]46.8[/td][/tr][br][tr][td]性[/td][td]F[/td][td]F[/td][td]F[/td][td]F[/td][td]F[/td][td]F[/td][td]F[/td][td]F[/td][td]F[/td][td]F[/td][/tr][br][/table][br][table][br][tr][td]番号[/td][td]11[/td][td]12[/td][td]13[/td][td]14[/td][td]15[/td][td]16[/td][td]17[/td][td]18[/td][td]19[/td][td]20[/td][/tr][br][tr][td]x[/td][td]184.9[/td][td]181.3[/td][td]171.4[/td][td]168.6[/td][td]162.3[/td][td]179.9[/td][td]179.5[/td][td]173.4[/td][td]167.9[/td][td]177.9[/td][/tr][br][tr][td]y[/td][td]75.5[/td][td]78.9[/td][td]66.2[/td][td]61.0[/td][td]55.7[/td][td]80.6[/td][td]66.1[/td][td]61.2[/td][td]61.3[/td][td]77.2[/td][/tr][br][tr][td]性[/td][td]M[/td][td]M[/td][td]M[/td][td]M[/td][td]M[/td][td]M[/td][td]M[/td][td]M[/td][td]M[/td][td]M[/td][/tr][br][/table][br]
[color=#0000ff][b]男子群p、女子群nを判別する直線w1x1+w2x2=bをw={{w1},{w2}},x={{x1},{x2}}と列ベクトルをとると、[br]法線ベクトルへの正射影の長さがbとなるベクトル方程式w[sup]t[/sup] x=bとなるね。[br]そのwとbを求めたい。[br][/b][/color]x=(x1,x2)のデータが原点を通り法線ベクトルを方向ベクトルとする「直線上L」に射影される。[br]男子x[sub]p[/sub]の平均ベクトルはm[sub]p[/sub],女子x[sub]n[/sub]の平均ベクトルはm[sub]n[/sub].[br]平均ベクトルの直線L上の射影はw' m[sub]p[/sub], w' m[sub]n[/sub]となる。'は随伴行列(共役複素数の転置)。[br]群間J=(w' m[sub]p[/sub] - w' m[sub]n[/sub])[sup]2[/sup]=(w' (m[sub]p[/sub]- m[sub]n[/sub]))(w'( m[sub]p[/sub] -m[sub]n[/sub]))'=w' [b][color=#0000ff](m[sub]p[/sub]-m[sub]n[/sub])(m[sub]p[/sub]-m[sub]n[/sub])'[/color] [/b]w = w' [b][color=#0000ff]B[/color][/b] wとおける。[br]群内I=Σ(w' m[sub]p[/sub]-w' x[sub]p[/sub])[sup]2[/sup]+Σ(w' m[sub]n[/sub]- w' x[sub]n[/sub])[sup]2[/sup]=w' ([color=#0000ff][b]Σ(m[sub]p[/sub]-x[sub]p[/sub])(m[sub]p[/sub]-x[sub]p[/sub])' + Σ(m[sub]n[/sub]-x[sub]n[/sub])(m[sub]n[/sub]-x[sub]n[/sub])'[/b][/color] ) w=w' [color=#0000ff][b]W [/b][/color]wとおく。[br]J/I=J(w)= w' B w / w' W w が最大になるwを求めるにはJ(w)のwでの偏微分=0から、[br]結局、B w =λ W wとなるwとλを求めればよいね。[br]Wの逆行列をかけると、inv(W)B w =λ w という固有方程式ができる。[br]だから、行列inv(W)Bの[b]固有値[/b]をα、[b]固有ベクトル[/b]をwとすると、.......(途中略)。[color=#0000ff][b]w=inv(W)(m[sub]p[/sub]-m[sub]n[/sub])[/b][/color].[br]wは2つの群の平均m[sub]p[/sub]、m[sub]n[/sub]の中心を2等分する直線の傾き、つまり、平均の差のベクトルに直交して、[br]中点を通ればよいのです。正射影の長さbは直線L上での中点だから、[b]b[/b]=[color=#0000ff][b]1/2(w'(m[sub]p[/sub]+m[sub]n[/sub])) [/b][/color][br]これで判別直線w[sup]t[/sup] x=bと決まるね。
[size=150][b]<julia>[br][/b][/size][b][color=#38761d]using[/color] LinearAlgebra[br][color=#38761d]using[/color] Statistics[br][/b]x11=[184.9,181.3,171.4,168.6,162.3,179.9,179.5,173.4,167.9,177.9][br]x12=[151.1,155.9,159.4,154.6,162.9,158.3,171.7,160.8,153.4,161.2][br]x21=[75.5,78.9,66.2,61.0,55.7,80.6,66.1,61.2,61.3,77.2][br]x22=[43.7,46.2,49.5,56.3,50.9,63.5,59.8,51.7,58.3,46.8][br]m1=[mean(x11) mean(x21)][br]m2=[mean(x12) mean(x22)][br]x1=[[x1 x2] for (x1,x2) in zip(x11,x21)] [br]x2=[[x1 x2] for (x1,x2) in zip(x12,x22)][br]mx1=[transpose([m1[1]-p m1[2]-q]) for (p,q) in x1] [br]mx2=[transpose([m2[1]-p m2[2]-q]) for (p,q) in x2][br]s1=sum([[x[1]^2 x[1]*x[2] ;x[1]*x[2] x[2]^2] for x in mx1])[br]s2=sum([[x[1]^2 x[1]*x[2] ;x[1]*x[2] x[2]^2] for x in mx2])[br][b]W=s1+s2[br]invW=inv(W)[br]w[/b] = normalize([b]invW *(m1-m2)'[/b])[br][b]mid=(m1+m2) *0.5[br][/b][color=#1155cc]#w[1]x1+w[2]x2-b=0 slope= -w[1]/w[2] y_intersect=b/w[2][br][/color]sl=floor(-w[1]*10/w[2])/10[br]yi= mid[2]+ mid[1]*(-sl)[br]print("sl=",sl,",yi=",yi)[br][color=#1155cc]#========================================================[br]sl=-3.1,yi=577.662[/color]
[size=150][b]<julia>[br][/b][color=#38761d]#[b]LDA線形判別分析[/b](Linear Discriminant Analysis)のモジュールを使う。[br][size=150]#Xp,Xnを行列で渡しやすくするためDataFramesモジュールを使う。[br] w'x+b=0を解くので、bの符号は上記と反対になるので注意。[br][/size][/color][/size][color=#38761d][b]using[/b][/color] [b]LinearAlgebra[/b][br][color=#38761d][b]using[/b][/color] [b]DataFrames[/b] [br][color=#38761d][b]using[/b][/color] [b]MultivariateStats[br][/b]xp=[184.9,181.3,171.4,168.6,162.3,179.9,179.5,173.4,167.9,177.9][br]yp=[75.5,78.9,66.2,61.0,55.7,80.6,66.1,61.2,61.3,77.2][br]xn=[151.1,155.9,159.4,154.6,162.9,158.3,171.7,160.8,153.4,161.2][br]yn=[43.7,46.2,49.5,56.3,50.9,63.5,59.8,51.7,58.3,46.8][br]p=DataFrame(datax=xp,datay=yp)[br]n=DataFrame(datax=xn,datay=yn)[br][b][color=#38761d]#データをデータフレームから取り出して行列として渡すためにfloat.(Matrix()')でくるむ。[br][/color][/b]Xp = float.([b]Matrix[/b](p)');[br]Xn = float.([b]Matrix[/b](n)');[br]f = [b]MultivariateStats[/b].fit([b]LinearDiscriminant[/b], Xp, Xn);[br][b]w=f.w[br]b=f.b[br][/b]sl=-w[1]/w[2][br]yi=-b/w[2][br]print("sl=",sl,",yi=",yi)[br][color=#6d9eeb]#========================================================[br]sl=-3.022262684795434,yi=564.6938610775744[/color][br]
[b][color=#38761d]# PyPlotを使って視覚化する。[br]using[/color] PyPlot[/b][br]xp=[184.9,181.3,171.4,168.6,162.3,179.9,179.5,173.4,167.9,177.9][br]yp=[75.5,78.9,66.2,61.0,55.7,80.6,66.1,61.2,61.3,77.2][br]xn=[151.1,155.9,159.4,154.6,162.9,158.3,171.7,160.8,153.4,161.2][br]yn=[43.7,46.2,49.5,56.3,50.9,63.5,59.8,51.7,58.3,46.8][br][br][b][color=#38761d]# 男子の点[/color][/b][br]plt.[b]scatter[/b](xp,yp)[br]m = ["m","m","m","m","m","m","m","m","m","m"][br]for (i, txt) in enumerate(m)[br] plt.[b]annotate[/b](txt, (xp[i], yp[i]))[br]end[br][color=#38761d][b]# 女子の点[/b][/color][br]plt.[b]scatter[/b](xn,yn)[br]f = ["f","f","f","f","f","f","f","f","f","f"][br]for (i, txt) in enumerate(f)[br] plt.annotate(txt, (xn[i], yn[i]))[br]end[br][color=#38761d][b]# 判別直線[/b][/color][br]xs=150:185[br][color=#38761d][b]#+はブロードキャストのため.+にする。[br][/b][/color]ys=(xs.*sl).+yi[br]plt.[b]plot[/b](xs,ys)
[b][size=150]<Python>[/size][/b][br][b][color=#38761d]import[/color] numpy as np[br][color=#38761d]from[/color] statistics [color=#38761d]import[/color] mean[br][color=#38761d]from[/color] numpy.linalg [color=#38761d]import[/color] inv,norm[br][/b]xp=np.array([184.9, 181.3, 171.4, 168.6, 162.3, 179.9, 179.5, 173.4 ,167.9 ,177.9])[br]yp=np.array([75.5,78.9, 66.2, 61.0, 55.7, 80.6, 66.1 , 61.2, 61.3, 77.2 ])[br]xn=np.array([151.1, 155.9, 159.4, 154.6, 162.9,158.3,171.7,160.8,153.4,161.2])[br]yn=np.array([43.7, 46.2, 49.5, 56.3, 50.9, 63.5, 59.8, 51.7, 58.3, 46.8])[br]mp=np.array([mean(xp) ,mean(yp)])[br]mn=np.array([mean(xn) ,mean(yn)])[br]xp=[np.array([x1 ,x2]) for (x1,x2) in zip(xp,yp)] [br]xn=[np.array([x1 ,x2]) for (x1,x2) in zip(xn,yn)][br]mxp=[np.array([mp[0]-p, mp[1]-q]) for (p,q) in xp] [br]mxn=[np.array([mn[0]-p, mn[1]-q]) for (p,q) in xn][br]sp=sum([np.array([[x[0]**2 ,x[0]*x[1]] ,[x[0]*x[1], x[1]**2]]) for x in mxp])[br]sn=sum([np.array([[x[0]**2 ,x[0]*x[1]] ,[x[0]*x[1], x[1]**2]]) for x in mxn])[br]W=sp+sn[br]invW=inv(W)[br]w = invW @(mp-mn)[br]w = w/norm(w)[br]mid=(mp+mn) *0.5[br][color=#38761d]#w[1]x1+w[2]x2-b=0 slope= -w[1]/w[2] y_intersect=b/w[2][br][/color]sl=int(-w[0]*10/w[1])/10[br]yi= mid[1]+ mid[0]*(-sl)[br]print("sl=",sl,",yi=",yi)[br][color=#6d9eeb]#========================================================[br][/color][color=#3d85c6]sl= -3.0 ,yi= 560.98[br][/color][br]