テンソルに低タッカーランク性を課して,欠損値を補完する手法 HaLRTC を実装しました.
参考にした論文はこれ.Algorithm 4 です.
pubmed.ncbi.nlm.nih.gov
引用数すげえええ.こういう論文を書けるようになりたいものですね.再現するのは図6の実験です.
アルゴリズム本体はこんな感じ.特に苦労もなく書けました.
ただ,Y と M の初期値をどう決めたらよいか論文には書いてないように思います.とりあえず,全部ゼロで初期化してます.なお M は各反復の最初に X と Y から計算し直されるので,実際に結果に効くのは Y の初期値だけです.
https://ieeexplore.ieee.org/document/6138863
"""
    Dtau(X, tau)

Singular value thresholding (shrinkage) operator.

Compute the thin SVD of the matrix `X`, subtract `tau` from every singular
value (clamping negatives to zero), and rebuild the matrix from the shrunk
spectrum.
"""
function Dtau(X, tau)
    F = svd(X)
    shrunk = max.(F.S .- tau, 0)
    # Multiplying U by diag(shrunk) just scales column j of U by shrunk[j].
    return (F.U .* shrunk') * F.Vt
end
"""
HaLRTC: High Accuracy Low Rank Tensor Completion
See Algorithm 4 in the
[original paper](https://ieeexplore.ieee.org/document/6138863)
# Aruguments
- 'X' : input tensor
- 'W' : binary tensor if X_ijk is missing then W_ijk = 0, otherwise 1
- 'rho' : hyper parameter
- 'Xgt' : ground truth tensor for printing verbose
"""
function HaLRTC(X, W;rho=1.0e-5, iter_max=1000, verbose=true,Xgt=NaN)
idxs_missing = findall( W .== 0 )
Xhat = X
D = ndims(X)
J = size(X)
Y = Vector{Array{Float64,D}}(undef,D)
M = Vector{Array{Float64,D}}(undef,D)
for d = 1:D
Y[d] = zeros( J )
M[d] = zeros( J )
end
alpha = 1.0/D
for iter = 1:iter_max
for d = 1:D
Xd = tenmat(X,d)
Ydd = tenmat(Y[d],d)
M[d] = matten( Dtau( Xd .+ Ydd/rho, alpha/rho ), d, [J...])
end
tmp = ( sum(M) .- sum(Y)./rho ) ./ D
Xhat .= (1 .- W) .* tmp + W .* X
for d = 1:D
Y[d] .-= rho.*( M[d] .- X )
end
if verbose
if mod(iter,5) == 0
rms = norm(Xhat .- Xgt)/norm(Xgt)
@show (iter, rms)
end
end
end
return Xhat
end
実験用に低タッカーランクのテンソルと欠損パターン(重みテンソル)をつくります.J がテンソルのサイズ,R がタッカーランクです.総和がテンソルの要素数と同じになるように(つまり要素の平均が 1 になるように)規格化しておきます.W は重みテンソルで,欠損しているインデックスは 0,それ以外は 1 のバイナリテンソルです.sr は観測率(%,0〜100 の整数)で,各要素が確率 sr/100 で観測されます.sr=0 は全て欠損,sr=100 は欠損なしを意味します.
"""
    get_low_tucker_tensor(J, R)

Build a random tensor of size `J` with Tucker rank at most `R`.

The core has size `R` with entries from `rand`. Each mode `d` is multiplied
(via `ttm`) by a `J[d] × R[d]` factor matrix with entries from
`rand(...) - 0.5`. The result is rescaled so that its entries sum to `prod(J)`.

NOTE(review): the scale factor divides by `sum(T)`, which can be close to zero
or negative because the factors are centered around 0 — consider normalizing
by `norm(T)` instead.
"""
function get_low_tucker_tensor(J, R)
    order = length(J)
    core = rand(R...)
    # Factors are drawn in mode order, after the core, as before.
    factors = [rand(J[d], R[d]) .- 0.5 for d in 1:order]
    modes = collect(1:order)
    T = ttm(core, factors, modes)
    return T .* (prod(J) ./ sum(T))
end
"""
    generate_weight(T; sr=30)

Draw a random binary (0.0 / 1.0) observation mask with the same size as `T`.

Every entry is sampled independently from 100 slots, of which the first `sr`
hold 1.0 and the rest hold 0.0. Each entry is therefore 1.0 (observed) with
probability `sr`/100. `sr` is expected to be an integer percentage in 0:100.
"""
function generate_weight(T; sr=30)
    slots = fill(0.0, 100)
    slots[1:sr] .= 1
    mask = rand(slots, size(T))
    return mask
end
欠損している部分の初期値は,平均 μ,分散 1 の正規分布からサンプリングして決めておきます.
μ は欠損していない部分の平均です.これは論文に書いてあった初期値の決め方です.
"""
    init_missing_val!(X, W)

Fill the missing entries of `X` (positions where `W` is 0) in place.

The values are draws from a normal distribution with standard deviation 1 and
mean `sum(W .* X) / sum(W)`, which is the mean of the observed entries when `W`
is binary. Return the modified `X`.
"""
function init_missing_val!(X, W)
    mu = sum(W .* X) / sum(W)
    holes = findall(iszero, W)
    X[holes] = rand(Normal(mu, 1), length(holes))
    return X
end
実行コードは以下.論文の問題設定に合わせて観測率20%にしておきます.
"""
    main()

HaLRTC experiment: complete a 50×50×50×50 tensor of Tucker rank 2 observed at
20% and print the relative error for several values of `rho`.
"""
function main()
    r = 2
    Xgt = get_low_tucker_tensor([50, 50, 50, 50], [r, r, r, r])
    W = generate_weight(Xgt, sr=20)
    Xin = init_missing_val!(copy(Xgt), W)
    for rho in [1.0e-6, 1.0e-7, 1.0e-8, 1.0e-9]
        # Give each run its own copy so every rho starts from the same initial
        # tensor, even if HaLRTC modifies its argument in place.
        Xpre = HaLRTC(copy(Xin), W, rho=rho, iter_max=50, verbose=true, Xgt=Xgt)
        @show (r, rho, RSE(Xgt, Xpre))
    end
    return nothing
end
"""
    RSE(a, b)

Relative error of `b` with respect to the reference `a`: `norm(a - b) / norm(a)`.
"""
function RSE(a, b)
    residual = a - b
    return norm(residual) / norm(a)
end
# Run the HaLRTC experiment defined in `main` above (rho sweep on a 50^4 tensor).
main()
今回は学習の様子を知りたいので,HaLRTC に Xgt(ground truth)を渡してます.実際の問題設定では正解データは手に入らないので注意してください(Xgt は verbose で誤差を表示する以外には使っていません).実行結果は以下.
% rho = 1.0e-6
(iter, rms) = (5, 0.8869150573563238)
(iter, rms) = (10, 0.8787654368905992)
(iter, rms) = (15, 0.8706546764827426)
(iter, rms) = (20, 0.8625800836366997)
(iter, rms) = (25, 0.8545393221335581)
(iter, rms) = (30, 0.8465303461954946)
(iter, rms) = (35, 0.8385513502138886)
(iter, rms) = (40, 0.8306007297171744)
(iter, rms) = (45, 0.822677050591317)
(iter, rms) = (50, 0.8147790244567812)
(r, rho, RSE(Xgt, Xpre)) = (2, 1.0e-6, 0.8147790244567812)
% rho = 1.0e-7
(iter, rms) = (5, 0.7370755291251271)
(iter, rms) = (10, 0.6611151830246284)
(iter, rms) = (15, 0.5865472507974067)
(iter, rms) = (20, 0.5132093749028765)
(iter, rms) = (25, 0.4410342512627796)
(iter, rms) = (30, 0.3700098409582638)
(iter, rms) = (35, 0.3001540041695475)
(iter, rms) = (40, 0.2315108108075831)
(iter, rms) = (45, 0.16415926359343516)
(iter, rms) = (50, 0.09823642969477969)
(r, rho, RSE(Xgt, Xpre)) = (2, 1.0e-7, 0.09823642969477969)
% rho = 1.0e-8
(iter, rms) = (5, 0.11760666100628855)
(iter, rms) = (10, 0.10256307545489725)
(iter, rms) = (15, 0.03393611809759228)
(iter, rms) = (20, 0.013397645459977311)
(iter, rms) = (25, 0.019970917994870968)
(iter, rms) = (30, 0.010004052083471028)
(iter, rms) = (35, 0.0013604313052622036)
(iter, rms) = (40, 0.0033711415897143713)
(iter, rms) = (45, 0.002407736831070906)
(iter, rms) = (50, 0.00060861557570968)
(r, rho, RSE(Xgt, Xpre)) = (2, 1.0e-8, 0.00060861557570968)
% rho = 1.0e-9
(iter, rms) = (5, 0.8434343422817917)
(iter, rms) = (10, 0.30375723857457604)
(iter, rms) = (15, 0.16121070156092238)
(iter, rms) = (20, 0.09762730684576809)
(iter, rms) = (25, 0.0447844565537104)
(iter, rms) = (30, 0.014794610766144974)
(iter, rms) = (35, 0.01559292614989365)
(iter, rms) = (40, 0.011395916029972048)
(iter, rms) = (45, 0.0062374608175992325)
(iter, rms) = (50, 0.004374149421615887)
(r, rho, RSE(Xgt, Xpre)) = (2, 1.0e-9, 0.004374149421615887)
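まぁちゃんと下がってますし,良いんじゃないでしょうか.どうも rho=1.0e-6 のときはうまくいきませんね.論文では10回繰り返した,と書いてあったのですが,これはきっと10回繰り返して最もよかったものをプロットした,ということだと思うので,まぁ入力テンソルを変えれば,Fig.6 くらい(0.55くらい?)までは下がるのかな.
ただし注意点があります.上の実行結果は,同じ Xin を各 rho で使い回して得たものです.しかも HaLRTC の中で Xhat = X としていたため,入力 Xin 自体が書き換わっていました.そのため rho=1.0e-7 以降は,直前の rho の結果から再開した状態になっています(rho=1.0e-7 の iter 5 の誤差が rho=1.0e-6 の最終値より小さいのはこのためです).rho ごとに公平に比べるには,SiLRTC のときと同様に入力のコピーを渡す必要があります.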
ついでにこないだの SiLRTC も貼っておきます.
"""
SiLRTC: Simple Low Rank Tensor Completion
See Algorithm 1 in the
[original paper](https://ieeexplore.ieee.org/document/6138863)
# Aruguments
- 'X' : input tensor
- 'W' : binary tensor if X_ijk is missing then W_ijk = 0, otherwise 1
- 'tau' : hyper parameter
- 'Xgt' : ground truth tensor for printing verbose
"""
function SiLRTC(X, W, tau=1.0e2; iter_max=1000, verbose=false,Xgt=NaN)
idxs_missing = findall( W .== 0 )
D = ndims(X)
beta = ones(D)
sizeX = size(X)
sum_beta = sum(beta)
M = Vector{Matrix{Float64}}(undef,D)
for iter = 1:iter_max
for d = 1 : D
Xd = tenmat(X,d)
M[d] = Dtau(Xd, tau)
end
foldM = zeros(size(X))
for d = 1:D
foldM .+= ( beta[d] .* matten( M[d], d, [size(X)...] ))
end
foldM = foldM ./ sum_beta
X[idxs_missing] .= foldM[idxs_missing]
if verbose
if iter % 200 == 0
rms = norm(X .- Xgt)/norm(Xgt)
@show (iter, rms)
end
end
end
return X
end
"""
    main()

SiLRTC experiment: complete a 50×50×50 tensor of Tucker rank 10 observed at
30% and print the relative error for thresholds `tau` = 10, 100 and 1000.
"""
function main()
    r = 10
    Xgt = get_low_tucker_tensor(fill(50, 3), fill(r, 3))
    W = generate_weight(Xgt; sr=30)
    Xin = init_missing_val!(copy(Xgt), W)
    for tau in (10, 100, 1000)
        # Each run gets its own copy of the initial tensor.
        Xpre = SiLRTC(copy(Xin), W, tau; iter_max=2000, verbose=true, Xgt=Xgt)
        @show (r, tau, RSE(Xgt, Xpre))
    end
end
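実行結果は以下.tau が小さいと,収束するのにもっともっと iter が必要ということですな.しきい値 tau が小さいと特異値がほとんど削られず,1 回の反復で欠損部分がわずかしか更新されないためです.実際,tau=10 で 2000 回回したときの誤差(約 0.779)と,tau=100 で 200 回回したときの誤差はほぼ同じです.tau=100 の 2000 回と tau=1000 の 200 回(約 0.393)も同様なので,だいたい「tau × 反復回数」で進み具合が決まっているようです.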
(iter, rms) = (200, 0.8306453721340777)
(iter, rms) = (400, 0.8247437581588134)
(iter, rms) = (600, 0.8188849564191533)
(iter, rms) = (800, 0.8130680035360011)
(iter, rms) = (1000, 0.8072919985188617)
(iter, rms) = (1200, 0.8015560976208085)
(iter, rms) = (1400, 0.7958595097055471)
(iter, rms) = (1600, 0.7902014920666872)
(iter, rms) = (1800, 0.7845813466478869)
(iter, rms) = (2000, 0.7789984166179046)
(r, tau, RSE(Xgt, Xpre)) = (10, 10, 0.7789984166179046)
(iter, rms) = (200, 0.7789920277327693)
(iter, rms) = (400, 0.7250842431553487)
(iter, rms) = (600, 0.6743908099006761)
(iter, rms) = (800, 0.6266246907129802)
(iter, rms) = (1000, 0.5815967124956913)
(iter, rms) = (1200, 0.5391698153982848)
(iter, rms) = (1400, 0.4992340492634431)
(iter, rms) = (1600, 0.4616920768892733)
(iter, rms) = (1800, 0.42645050231781445)
(iter, rms) = (2000, 0.3934146886184567)
(r, tau, RSE(Xgt, Xpre)) = (10, 100, 0.3934146886184567)
(iter, rms) = (200, 0.39309150992974745)
(iter, rms) = (400, 0.16002034419257943)
(iter, rms) = (600, 0.03314117313566921)
(iter, rms) = (800, 0.0224660080703799)
(iter, rms) = (1000, 0.022466003544853275)
(iter, rms) = (1200, 0.022466003544851655)
(iter, rms) = (1400, 0.0224660035448515)
(iter, rms) = (1600, 0.022466003544851582)
(iter, rms) = (1800, 0.02246600354485149)
(iter, rms) = (2000, 0.022466003544851818)
(r, tau, RSE(Xgt, Xpre)) = (10, 1000, 0.022466003544851818)
関連して Julia の実装を見つけたので貼っておきます.
github.com