以下の内容はhttps://lpha-z.hatenablog.com/entry/2025/01/12/231500より取得しました。


行列乗算のさらなる最適化(コンシューマー向けGPU編)

以前の記事(行列乗算の最適化入門(コンシューマー向けGPU編) - よーる)で、RTX 4090向けの単精度行列積コードを作成し、cuBLASの98.4%である58.3 TFLOPSの性能を達成しました。 そのコードを更に最適化した結果、cuBLASを超える60.0 TFLOPSの性能を出せる単精度行列積コードを作れたので、その改良を紹介します。

実験条件等は前回の記事に準じます。

グローバルメモリからの読み出しをソフトウェアパイプライン化

以前の記事で紹介したコードは、以下のような構造でした。

for( int k0 = 0; k0 < N; k0 += SharedMemBlockK ) {
    __syncthreads();
    グローバルメモリからシェアードメモリへの転送
    __syncthreads();
    シェアードメモリからレジスタへの転送と、レジスタ上での行列積計算
}

この方式は、グローバルメモリの読み出しレイテンシを隠蔽することができないことが、性能上問題となります。 そこで、以下のような構造に変更します。

for( int k0 = 0; k0 < N; k0 += SharedMemBlockK ) {
    __syncthreads();
    レジスタからシェアードメモリへの転送
    __syncthreads();
    グローバルメモリからレジスタへの読み出し
    シェアードメモリからレジスタへの転送と、レジスタ上での行列積計算
}

なお、これはソフトウェアパイプラインなので、初回と最終回を特別扱いする必要があります。 カーネル部分を以下のように改変することで、59.0 TFLOPSを達成しました。

    float4 tmp_reg_a[LoopX/4];
    float4 tmp_reg_b[LoopZ/4][LoopY];
    for( int x = 0; x < LoopX/4; ++x ) {
        int il = ( x * ThreadsPerBlock + tid ) / (SharedMemBlockK/4);
        int kl = ( x * ThreadsPerBlock + tid ) % (SharedMemBlockK/4) * 4;
        int ig = i0 + il;
        int kg = 0 + kl;
        tmp_reg_a[x] = *reinterpret_cast<float4*>(&a[ig*N+kg]);
    }
    for( int z = 0; z < LoopZ/4; ++z )
    for( int y = 0; y < LoopY; ++y )
    {
        int kl = y * ThreadsPerBlockI + i2;
        int jl = ( z * ThreadsPerBlockJ + j2 ) * 4;
        int kg = 0 + kl;
        int jg = j0 + jl;
        tmp_reg_b[z][y] = *reinterpret_cast<float4*>(&b[kg*N+jg]);
    }

    for( int k0 = 0; k0 < N - SharedMemBlockK; k0 += SharedMemBlockK ) {
        __syncthreads();
        for( int x = 0; x < LoopX/4; ++x ) {
            int il = ( x * ThreadsPerBlock + tid ) / (SharedMemBlockK/4);
            int kl = ( x * ThreadsPerBlock + tid ) % (SharedMemBlockK/4) * 4;
            *reinterpret_cast<float4*>(&local_a[il][kl]) = tmp_reg_a[x];
        }
        for( int z = 0; z < LoopZ/4; ++z )
        for( int y = 0; y < LoopY; ++y )
        {
            int kl = y * ThreadsPerBlockI + i2;
            int jl = ( z * ThreadsPerBlockJ + j2 ) * 4;
            *reinterpret_cast<float4*>(&local_b[kl][jl]) = tmp_reg_b[z][y];
        }
        __syncthreads();

        for( int x = 0; x < LoopX/4; ++x ) {
            int il = ( x * ThreadsPerBlock + tid ) / (SharedMemBlockK/4);
            int kl = ( x * ThreadsPerBlock + tid ) % (SharedMemBlockK/4) * 4;
            int ig = i0 + il;
            int kg = (k0 + SharedMemBlockK) + kl;
            tmp_reg_a[x] = *reinterpret_cast<float4*>(&a[ig*N+kg]);
        }
        for( int z = 0; z < LoopZ/4; ++z )
        for( int y = 0; y < LoopY; ++y )
        {
            int kl = y * ThreadsPerBlockI + i2;
            int jl = ( z * ThreadsPerBlockJ + j2 ) * 4;
            int kg = (k0 + SharedMemBlockK) + kl;
            int jg = j0 + jl;
            tmp_reg_b[z][y] = *reinterpret_cast<float4*>(&b[kg*N+jg]);
        }

        for( int k1 = 0; k1 < SharedMemBlockK; ++k1 )
        for( int i1 = 0; i1 < RegisterBlockI; ++i1 )
        for( int j1 = 0; j1 < RegisterBlockJ/4; ++j1 )
        for( int j3 = 0; j3 < 4; ++j3 )
        {
            int il = i1 * ThreadsPerBlockI + i2;
            int jl = ( j1 * ThreadsPerBlockJ + j2 ) * 4 + j3;

            sum[j1*4+j3][i1] = fma( local_a[il][k1], local_b[k1][jl], sum[j1*4+j3][i1] );
        }
    }
    __syncthreads();
    for( int x = 0; x < LoopX/4; ++x ) {
        int il = ( x * ThreadsPerBlock + tid ) / (SharedMemBlockK/4);
        int kl = ( x * ThreadsPerBlock + tid ) % (SharedMemBlockK/4) * 4;
        *reinterpret_cast<float4*>(&local_a[il][kl]) = tmp_reg_a[x];
    }
    for( int z = 0; z < LoopZ/4; ++z )
    for( int y = 0; y < LoopY; ++y )
    {
        int kl = y * ThreadsPerBlockI + i2;
        int jl = ( z * ThreadsPerBlockJ + j2 ) * 4;
        *reinterpret_cast<float4*>(&local_b[kl][jl]) = tmp_reg_b[z][y];
    }
    __syncthreads();

    for( int k1 = 0; k1 < SharedMemBlockK; ++k1 )
    for( int i1 = 0; i1 < RegisterBlockI; ++i1 )
    for( int j1 = 0; j1 < RegisterBlockJ/4; ++j1 )
    for( int j3 = 0; j3 < 4; ++j3 )
    {
        int il = i1 * ThreadsPerBlockI + i2;
        int jl = ( j1 * ThreadsPerBlockJ + j2 ) * 4 + j3;

        sum[j1*4+j3][i1] = fma( local_a[il][k1], local_b[k1][jl], sum[j1*4+j3][i1] );
    }

ただし、cuobjdumpで見てみると、グローバルメモリの読み出しはループのかなり後の方で行われていて、有効にソフトウェアパイプライン化できていないようでした。 おそらく、レジスタの圧力に負けて後ろに追いやられている(生存期間を短くしようとする圧力がかかっている)ことが原因です。 命令移動を妨害するバリア的なものを置きたいところですが、性能が下がらないようなバリアを見つけることができませんでした。 レジスタ割り付けを手動で行う(具体的な番号はコンパイラ任せだが、変数をどこで上書きするかを明示する)ことでレジスタ圧を下げることができるようでしたが、すべてをコンパイラに任せるより性能が上がる割り付けを見つけることはできませんでした。

シェアードメモリの宣言順序変更

    __align__(16) __shared__ float local_a[SizeI*2][SharedMemBlockK]; // 8 KiB
    __align__(16) __shared__ float local_b[SharedMemBlockK*2][SizeJ]; // 32 KiB

    __align__(16) __shared__ float local_b[SharedMemBlockK*2][SizeJ]; // 32 KiB
    __align__(16) __shared__ float local_a[SizeI*2][SharedMemBlockK]; // 8 KiB

の順番に変更することで、59.0 TFLOPS → 59.3 TFLOPS となりました。 local_bのアラインが良くなったことが原因なのではないかと考えていますが、正直なところ、よくわからない性能向上です。

ダブルバッファリング再び

CUTLASSというライブラリのSGEMMはダブルバッファリングをしているようなので、それをまねしたりしてみるのですが、やはり性能が上がりません。 どうも、コードが複雑化することでコンパイラの最適化が働きにくくなっているような感じがあります。 バリアが一個なくなることは1%くらいの性能向上に寄与するようなので、原理的には性能向上するはずなのですが……。

また、レジスタへのプリフェッチもやってみましたが、レジスタ圧がさらに上がってしまって厳しいようです。 レジスタへのプリフェッチは、バリアの直後にLDS.128が配置されてレイテンシを隠蔽できていない問題を解決する方法で、ダブルバッファリングを適用した場合のみ可能です。 なお、レジスタへのプリフェッチ中に多少の行列乗算をやることが最適になりますが、それをやってもダブルバッファリングなしの場合より高い性能を得ることはできませんでした。

レジスタ圧の問題は、ワープで計算する行列の形状を正方形に近づけることで多少改善するのですが、コアレスアクセスではなくなるのか、かえって性能が下がってしまいます。 ワープ内の各スレッドがどの箇所を計算するのかを、もう一度見直す必要があるのかもしれません。

配列アクセスをポインタに変更

cuobjdumpを眺めていると、ループ内にそこそこの整数演算が入っていることがわかります。 なんでこんな古代の技術を持ち出さないといけないのかという感じですが、配列アクセスをポインタアクセスに変更することで、この整数演算を減らします。 可読性が非常に低下します。 また、シェアードメモリを二倍確保するのをやめたほうが性能が高かったので、*2を消しました。

__global__ void kernel( float* a, float* b, float* c ) {
    int i0 = blockIdx.y * SizeI;
    int j0 = blockIdx.x * SizeJ;
    int i2 = threadIdx.y;
    int j2 = threadIdx.x;
    int tid = threadIdx.y * ThreadsPerBlockJ + threadIdx.x;

    int it = tid / (SharedMemBlockK/4);
    int kt = tid % (SharedMemBlockK/4) * 4;
    int jt = j2 * 4;

    // 128 registers
    float sum[RegisterBlockJ][RegisterBlockI] = {};

    __align__(16) __shared__ float local_b[SharedMemBlockK*SizeJ]; // 16 KiB
    __align__(16) __shared__ float local_a[SizeI*SharedMemBlockK]; // 4 KiB

    // 40 registers
    float4 tmp_reg_a[LoopX/4];
    float4 tmp_reg_b[LoopZ/4][LoopY];

    float* p_ga = &a[(i0+it)*N+(kt)];
    float* p_gb = &b[(i2)*N+(j0+jt)];
    float* p_sa = &local_a[it*SharedMemBlockK+kt];
    float* p_sb = &local_b[i2*SizeJ+jt];
    float* p_la = &local_a[i2*SharedMemBlockK];
    float* p_lb = &local_b[jt];
    float* p_gc = &c[(i0+i2)*N+(j0+jt)];

    for( int x = 0; x < LoopX/4; ++x ) {
        int ix = x * ThreadsPerBlock / (SharedMemBlockK/4);
        int kx = x * ThreadsPerBlock % (SharedMemBlockK/4) * 4;
        tmp_reg_a[x] = *reinterpret_cast<float4*>(&p_ga[ix*N+kx]);
    }
    for( int z = 0; z < LoopZ/4; ++z )
    for( int y = 0; y < LoopY; ++y )
    {
        int ky = y * ThreadsPerBlockI;
        int jz = z * ThreadsPerBlockJ * 4;
        tmp_reg_b[z][y] = *reinterpret_cast<float4*>(&p_gb[ky*N+jz]);
    }

    for( int k0 = 0; k0 < N - SharedMemBlockK; k0 += SharedMemBlockK ) {
        __syncthreads();
        for( int x = 0; x < LoopX/4; ++x ) {
            int ix = x * ThreadsPerBlock / (SharedMemBlockK/4);
            int kx = x * ThreadsPerBlock % (SharedMemBlockK/4) * 4;
            *reinterpret_cast<float4*>(&p_sa[ix*SharedMemBlockK+kx]) = tmp_reg_a[x];
        }
        for( int z = 0; z < LoopZ/4; ++z )
        for( int y = 0; y < LoopY; ++y )
        {
            int ky = y * ThreadsPerBlockI;
            int jz = z * ThreadsPerBlockJ * 4;
            *reinterpret_cast<float4*>(&p_sb[ky*SizeJ+jz]) = tmp_reg_b[z][y];
        }
        __syncthreads();

        p_ga += SharedMemBlockK;
        p_gb += SharedMemBlockK * N;
        for( int x = 0; x < LoopX/4; ++x ) {
            int ix = x * ThreadsPerBlock / (SharedMemBlockK/4);
            int kx = x * ThreadsPerBlock % (SharedMemBlockK/4) * 4;
            tmp_reg_a[x] = *reinterpret_cast<float4*>(&p_ga[ix*N+kx]);
        }
        for( int z = 0; z < LoopZ/4; ++z )
        for( int y = 0; y < LoopY; ++y )
        {
            int ky = y * ThreadsPerBlockI;
            int jz = z * ThreadsPerBlockJ * 4;
            tmp_reg_b[z][y] = *reinterpret_cast<float4*>(&p_gb[ky*N+jz]);
        }

        for( int k1 = 0; k1 < SharedMemBlockK; ++k1 )
        for( int i1 = 0; i1 < RegisterBlockI; ++i1 )
        for( int j1 = 0; j1 < RegisterBlockJ/4; ++j1 )
        for( int j3 = 0; j3 < 4; ++j3 )
        {
            int il = i1 * ThreadsPerBlockI;
            int jl = j1 * ThreadsPerBlockJ * 4 + j3;

            sum[j1*4+j3][i1] = fma( p_la[il*SharedMemBlockK+k1], p_lb[k1*SizeJ+jl], sum[j1*4+j3][i1] );
        }
    }

    __syncthreads();
    for( int x = 0; x < LoopX/4; ++x ) {
        int ix = x * ThreadsPerBlock / (SharedMemBlockK/4);
        int kx = x * ThreadsPerBlock % (SharedMemBlockK/4) * 4;
        *reinterpret_cast<float4*>(&p_sa[ix*SharedMemBlockK+kx]) = tmp_reg_a[x];
    }
    for( int z = 0; z < LoopZ/4; ++z )
    for( int y = 0; y < LoopY; ++y )
    {
        int ky = y * ThreadsPerBlockI;
        int jz = z * ThreadsPerBlockJ * 4;
        *reinterpret_cast<float4*>(&p_sb[ky*SizeJ+jz]) = tmp_reg_b[z][y];
    }
    __syncthreads();

    for( int k1 = 0; k1 < SharedMemBlockK; ++k1 )
    for( int i1 = 0; i1 < RegisterBlockI; ++i1 )
    for( int j1 = 0; j1 < RegisterBlockJ/4; ++j1 )
    for( int j3 = 0; j3 < 4; ++j3 )
    {
        int il = i1 * ThreadsPerBlockI;
        int jl = j1 * ThreadsPerBlockJ * 4 + j3;

        sum[j1*4+j3][i1] = fma( p_la[il*SharedMemBlockK+k1], p_lb[k1*SizeJ+jl], sum[j1*4+j3][i1] );
    }

    for( int i1 = 0; i1 < RegisterBlockI; ++i1 )
    for( int j1 = 0; j1 < RegisterBlockJ/4; ++j1 )
    for( int j3 = 0; j3 < 4; ++j3 )
    {
        int i = i1 * ThreadsPerBlockI;
        int j = j1 * ThreadsPerBlockJ * 4 + j3;
        p_gc[i*N+j] += sum[j1*4+j3][i1];
    }
}

これで59.5 TFLOPSになりました。

コンパイルオプションでガチャを引く

ここまでのコードはnvcc main.cu -arch compute_86 -code sm_86 -lcublas --maxrregcount=255 -std=c++17コンパイルしていました。 この辺のオプションを変えるとコンパイラの出すコードが変わるので、性能が高くなるオプションを探します。 と言っても実際には何もオプションを指定しないnvcc main.cu -lcublas -std=c++17が最適で、その時に60.0 TFLOPSを達成しました。 一番良い時で、60.138 TFLOPSでした。

まとめ

RTX 4090上でcuBLAS(かなりブレがあるが、59.2~59.4 TFLOPSくらい。最大で59.981 TFLOPSだった)よりも高い性能である60.0 TFLOPSを達成する行列積コードを作ることができました。 コードの工夫としてはグローバルメモリからの読み出しをソフトウェアパイプライン化することが重要でした。 あとはコンパイラの機嫌をうかがうのが大事(?)なようです。




以上の内容はhttps://lpha-z.hatenablog.com/entry/2025/01/12/231500より取得しました。
このページはhttp://font.textar.tv/のウェブフォントを使用してます

不具合報告/要望等はこちらへお願いします。
モバイルやる夫Viewer Ver0.14