疎行列ベクトル積について調べなくてはならなくなったので、自分の備忘録として残しておく。
行列ベクトル積というのは普通にのことを指すが、行列Aが疎である場合、つまり多くの要素が0である場合には、メモリを節約するために0出ない要素のみを記憶し、その位置をインデックスで保持する。
これを疎行列ベクトル積と呼び、SpMV (Sparse Matrix-Vector Multiplication)と呼ばれている。
例えば、以下のの行列と4列のベクトルの行列積を取る場合、オレンジの部分のみに値が入っており、それ以外の部分は0であるとする。
各要素に入っている値はグローバルなインデックスとする。

この場合、各行列の要素をすべて端に寄せ、その位置を覚える。
ptr: 各列の先頭要素が何番目のインデックスから始まるかを示す。idx: 各列において、それぞれの要素が何番目の行列の場所に入っているのかを示す。


これをC言語で書くとこんな感じになる。
void spmv(int r, const double* val, const int* idx, const double* x, const int* ptr, double* y) { for (int i = 0; i < r; i++) { int k; for (k = ptr[i]; k < ptr[i+1]; k++) { y[i] += val[k]*x[idx[k]]; } } }
これのテストをしたいのだが、RISC-Vのテストベンチマークセットに便利なScalaのコードを見つけたので、これを真似する。 というか、これを実行しようとしたらなぜかScalaがランタイムエラーを出したので、仕方がないのでRubyで書き換えた。
#!/usr/bin/env ruby m = ARGV[0].to_i n = ARGV[1].to_i approx_nnz = ARGV[2].to_i pnnz = approx_nnz.to_f/(m*n) idx = Array[] p = [0] m.times {|i| n.times {|j| if rand() < pnnz then idx.push (j) end } p.push (idx.size) } nnz = idx.size v = Array.new(n) { rand(1000) } d = Array.new(nnz) { rand(1000) } def printVec(t, name, data) printf("const %s %s[%d] = {", t, name, data.length) data.each_with_index {|d, index| print " " + d.to_s puts "," if index != data.size-1 } print("};\n\n") end def spmv(p, d, idx, v) y = Array.new for i in 0..(p.length-1) do yi = 0 limit = 0 if i == p.length-1 then limit = idx.size else limit = p[i+1] end for k in p[i]..(limit-1) do yi = yi + d[k]*v[idx[k]] end y[i] = yi end return y end printf("#define R %d\n", m) printf("#define C %d\n", n) printf("#define NNZ %d\n", nnz) printVec("double", "val", d) printVec("uint64_t", "idx", idx) printVec("double", "x", v) printVec("uint64_t", "ptr", p) printVec("double", "verify_data", spmv(p, d, idx, v))
$ ./spmv_gendata.rb 4 4 8
#define R 4
#define C 4
#define NNZ 9
const double val[9] = { 236,
776,
140,
252,
760,
5,
829,
723,
383};
const int idx[9] = { 0,
3,
0,
1,
3,
0,
1,
2,
3};
const double x[4] = { 568,
605,
16,
209};
const int ptr[5] = { 0,
1,
1,
4,
8};
const double verify_data[5] = { 134048,
0,
394164,
674793,
80047};