Re: RE: [RISC-V] [tech-vector-ext] The scenarios of GEMM for u/int8 data


Linjie Yu
 

Hi David,

 

Can we see the git repository of your work?

My code has not been uploaded to git yet, so I will show it in this mail.

            Does this mean the 32 vector registers are not enough,
            or that the number of elements for the given input vector length is not enough?

Yes: the data has to be widened 4 times (from 8-bit inputs to 32-bit accumulators), so the 32 vector registers are not enough.

 

       With a "temporary working vector", this new instruction is a combination of the old one with any "insert scalar into element" instruction [such as vrgather.vv splat with mask].

 

Using vrgather.vv requires a 128-bit constant, which is complex to initialize.
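One way to read David's equivalence is as a scalar C model under stated assumptions. The model assumes VLEN = 128 (4 int32 lanes) and an indexed reduction that adds into element idx of the accumulator, which is how the approach 3 code below uses it. The decomposition into vslidedown.vi, vwredsum.vs and a masked vrgather.vi/vmerge.vvm pair is an illustration only, not a sequence proposed in the thread. The function name is hypothetical. Each target position needs its own one-hot mask (or index vector) constant, which appears to be the setup cost Damon refers to.

```c
#include <assert.h>
#include <stdint.h>

#define VL8 16      /* assumption: VLEN = 128, 16 int16 products in an m2 group */
#define LANES32 4   /* int32 lanes per register at VLEN = 128 */

/* Hypothetical model: indexed widening reduction built from existing-style
 * steps with a temporary register t.
 *   1. vslidedown.vi  t, vs1, idx   -> t[0] = vs1[idx]
 *   2. vwredsum.vs    t, vs2, t     -> t[0] += sum(vs2)
 *   3. vrgather.vi + vmerge.vvm under a one-hot mask -> vd[idx] = t[0]
 * vd and vs1 may be the same register. */
static void vwredsum_idx_decomposed(int32_t vd[LANES32], const int16_t vs2[VL8],
                                    const int32_t vs1[LANES32], int idx)
{
    int32_t t[LANES32] = {0};
    int32_t splat[LANES32];

    assert(idx >= 0 && idx < LANES32);

    /* Step 1: slide vs1 down by idx so the running sum sits in element 0;
       elements slid in from past the end are zero. */
    for (int i = 0; i + idx < LANES32; i++)
        t[i] = vs1[i + idx];

    /* Step 2: ordinary widening reduction; the result lands in element 0. */
    for (int i = 0; i < VL8; i++)
        t[0] += vs2[i];

    /* Step 3: splat t[0], then merge it into vd where the mask bit is set. */
    for (int i = 0; i < LANES32; i++)
        splat[i] = t[0];
    for (int i = 0; i < LANES32; i++)
        if (i == idx)   /* the one-hot mask constant: only bit idx is set */
            vd[i] = splat[i];
}
```

The other lanes of vd are left untouched, which is what lets several column sums share one register.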

 

Next, I will show my code:

 

First, the C code:

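        int sum[8];
        // initialize the 8 output columns with their bias
        for (int j = 0; j < 8; j++) {
            sum[j] = bias_ptr[j];
        }
        // each iteration consumes 16 input channels and one packed 16x8 weight block
        for (int j = 0; j < inch_16; j++) {
            for (int k = 0; k < 16; k++) {
                for (int x = 0; x < 8; x++) {
                    sum[x] += in_ptr[k] * f0[k + 16 * x];
                }
            }
            in_ptr += 16;
            f0 += 16 * 8;
        }
        // requantize: rounding right shift for a negative shift_value, else left shift,
        // then narrow to 8 bits (i comes from an enclosing loop not shown here)
        for (int j = 0; j < 8; j++) {
            int lshift = -shift_value[j + i];
            if (lshift > 0) {
                sum[j] = (sum[j] + (1 << (lshift - 1))) >> lshift;
            } else {
                sum[j] = sum[j] << (-lshift);
            }
            out_ptr[j] = (char)sum[j];
        }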
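To check the vector variants against a scalar baseline, the fragment above can be wrapped in a self-contained function. This is a sketch under assumptions. The name `gemm_int8_8col_ref`, the parameter types (signed 8-bit inputs and weights, 32-bit bias and shift), and passing the shift values already offset by `i` are all hypothetical. The arithmetic follows the C code above. The left shift is written as a multiplication so that negative sums stay well defined in C. Like the original, the final cast truncates rather than saturates.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical wrapper around the scalar kernel above.
 * in_ptr   : inch_16 * 16 int8 inputs
 * f0       : inch_16 packed blocks of 8 columns x 16 int8 weights
 * bias_ptr : 8 int32 biases
 * shift    : 8 shift values (shift_value + i in the original)
 * out_ptr  : 8 int8 outputs */
static void gemm_int8_8col_ref(const int8_t *in_ptr, const int8_t *f0,
                               const int32_t *bias_ptr, const int32_t *shift,
                               int inch_16, int8_t *out_ptr)
{
    int32_t sum[8];

    assert(inch_16 >= 0);
    for (int j = 0; j < 8; j++)
        sum[j] = bias_ptr[j];

    for (int j = 0; j < inch_16; j++) {
        for (int k = 0; k < 16; k++)
            for (int x = 0; x < 8; x++)
                sum[x] += in_ptr[k] * f0[k + 16 * x];
        in_ptr += 16;   /* next 16 input channels */
        f0 += 16 * 8;   /* next packed weight block */
    }

    for (int j = 0; j < 8; j++) {
        int lshift = -shift[j];
        if (lshift > 0)
            /* rounding right shift; relies on arithmetic >> for negative
               values, as the original does */
            sum[j] = (sum[j] + (1 << (lshift - 1))) >> lshift;
        else
            /* left shift expressed as a multiply (shift of 0 multiplies by 1) */
            sum[j] *= (int32_t)1 << -lshift;
        out_ptr[j] = (int8_t)sum[j];   /* truncating narrow, as in the original */
    }
}
```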

 

1. vdot.vv + vredsum.vs (the tail processing is very complex)

                    "vsetvli        zero, zero, e64, m8\n\t"
                    "vxor.vv        v0, v0, v0\n\t"
                    "beqz           %4, 1f\n\t"

                    "0: \n\t"
                    "vsetvli        zero, zero, e8, m1\n\t"
                    "vle.v          v8, (%0)\n\t"
                    "addi           %0, %0, 16\n\t"
                    "vle.v          v9, (%1)\n\t"
                    "addi           %1, %1, 16\n\t"
                    "vle.v          v10, (%1)\n\t"
                    "addi           %1, %1, 16\n\t"
                    "vle.v          v11, (%1)\n\t"
                    "addi           %1, %1, 16\n\t"
                    "vle.v          v12, (%1)\n\t"
                    "addi           %1, %1, 16\n\t"
                    "vle.v          v13, (%1)\n\t"
                    "addi           %1, %1, 16\n\t"
                    "vle.v          v14, (%1)\n\t"
                    "addi           %1, %1, 16\n\t"
                    "vle.v          v15, (%1)\n\t"
                    "addi           %1, %1, 16\n\t"
                    "vle.v          v16, (%1)\n\t"
                    "addi           %1, %1, 16\n\t"

                    "vsetvli        zero, zero, e32, m1, d4\n\t"
                    "vdot.vv        v0, v8, v9\n\t"
                    "vdot.vv        v1, v8, v10\n\t"
                    "vdot.vv        v2, v8, v11\n\t"
                    "vdot.vv        v3, v8, v12\n\t"
                    "vdot.vv        v4, v8, v13\n\t"
                    "addi           %4, %4, -1\n\t"
                    "vdot.vv        v5, v8, v14\n\t"
                    "vdot.vv        v6, v8, v15\n\t"
                    "vdot.vv        v7, v8, v16\n\t"
                    "bnez           %4, 0b\n\t"

                    "1: \n\t"
                    "vsetvli        zero, zero, e64, m8\n\t"
                    "vxor.vv        v8, v8, v8\n\t"
                    "vsetvli        zero, zero, e32, m1\n\t"
                    "vwredsum.vs    v8, v0, v8\n\t"
                    "vwredsum.vs    v9, v1, v9\n\t"
                    "vwredsum.vs    v10, v2, v10\n\t"
                    "vwredsum.vs    v11, v3, v11\n\t"
                    "vwredsum.vs    v12, v4, v12\n\t"
                    "vwredsum.vs    v13, v5, v13\n\t"
                    "vwredsum.vs    v14, v6, v14\n\t"
                    "vwredsum.vs    v15, v7, v15\n\t"
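A scalar model of what approach 1 computes for one output column may help when reasoning about the tail. It assumes VLEN = 128 and the experimental vdot.vv semantics with SEW = 32 and EDIV = 4, where each 32-bit lane accumulates the dot product of four int8 sub-elements. That is how the loop above uses it. The function names are hypothetical. After the loop, each column's sum is still spread over 4 lanes, which is why the final vwredsum.vs pass is needed.

```c
#include <assert.h>
#include <stdint.h>

#define VLEN_BYTES 16               /* assumption: VLEN = 128 bits */
#define LANES32 (VLEN_BYTES / 4)    /* int32 lanes per register */

/* Model of vdot.vv at SEW = 32, EDIV = 4: lane i accumulates the dot
 * product of the four int8 sub-elements a[4i..4i+3] and b[4i..4i+3]. */
static void vdot_e32_d4(int32_t acc[LANES32], const int8_t a[VLEN_BYTES],
                        const int8_t b[VLEN_BYTES])
{
    for (int i = 0; i < LANES32; i++)
        for (int e = 0; e < 4; e++)
            acc[i] += a[4 * i + e] * b[4 * i + e];
}

/* Model of the tail: vwredsum.vs at SEW = 32 sums the lanes into a
 * 64-bit element 0, starting from the scalar operand (zeroed above). */
static int64_t vwredsum_e32(const int32_t acc[LANES32], int64_t init)
{
    int64_t s = init;
    for (int i = 0; i < LANES32; i++)
        s += acc[i];
    return s;
}

/* Approach 1 for output column x (0..7): the loop keeps 4 partial sums
 * per column, and only the tail reduction produces the final value. */
static int64_t approach1_column(const int8_t *in, const int8_t *f0,
                                int inch_16, int x)
{
    int32_t acc[LANES32] = {0};

    assert(x >= 0 && x < 8);
    for (int j = 0; j < inch_16; j++)
        vdot_e32_d4(acc, in + 16 * j, f0 + 16 * 8 * j + 16 * x);
    return vwredsum_e32(acc, 0);
}
```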

 

 

2. vwmul + vwredsum.vs (vwredsum.vs used inside the loop)

                    "vsetvli        zero, zero, e64, m8\n\t"
                    "vxor.vv        v0, v0, v0\n\t"
                    "beqz           %4, 1f\n\t"

                    "0: \n\t"
                    "vsetvli        zero, zero, e8, m1\n\t"
                    "vle.v          v8, (%0)\n\t"
                    "addi           %0, %0, 16\n\t"
                    "vle.v          v9, (%1)\n\t"
                    "addi           %1, %1, 16\n\t"
                    "vle.v          v10, (%1)\n\t"
                    "addi           %1, %1, 16\n\t"
                    "vle.v          v11, (%1)\n\t"
                    "addi           %1, %1, 16\n\t"
                    "vle.v          v12, (%1)\n\t"
                    "addi           %1, %1, 16\n\t"
                    "vle.v          v13, (%1)\n\t"
                    "addi           %1, %1, 16\n\t"

                    "vwmul.vv       v14, v8, v9\n\t"
                    "vwmul.vv       v16, v8, v10\n\t"
                    "vle.v          v9, (%1)\n\t"
                    "addi           %1, %1, 16\n\t"
                    "vwmul.vv       v18, v8, v11\n\t"
                    "vle.v          v10, (%1)\n\t"
                    "addi           %1, %1, 16\n\t"
                    "vwmul.vv       v20, v8, v12\n\t"
                    "vle.v          v11, (%1)\n\t"
                    "addi           %1, %1, 16\n\t"

                    "vwmul.vv       v22, v8, v13\n\t"
                    "vwmul.vv       v24, v8, v9\n\t"
                    "vwmul.vv       v26, v8, v10\n\t"
                    "vwmul.vv       v28, v8, v11\n\t"

                    "vsetvli        zero, zero, e16, m2\n\t"
                    "vwredsum.vs    v0, v14, v0\n\t"
                    "vwredsum.vs    v1, v16, v1\n\t"
                    "vwredsum.vs    v2, v18, v2\n\t"
                    "addi           %4, %4, -1\n\t"
                    "vwredsum.vs    v3, v20, v3\n\t"
                    "vwredsum.vs    v4, v22, v4\n\t"
                    "vwredsum.vs    v5, v24, v5\n\t"
                    "vwredsum.vs    v6, v26, v6\n\t"
                    "vwredsum.vs    v7, v28, v7\n\t"
                    "bnez           %4, 0b\n\t"
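Approach 2 can be modeled the same way. Assuming VLEN = 128, vwmul.vv at SEW = 8 produces 16 int16 products in a register pair. vwredsum.vs at SEW = 16 then folds them into a 32-bit element 0, adding the previous value from vs1[0]. As a result, each column owns a whole register (v0 to v7) but uses only one lane of it, and eight reductions run on every iteration. The names below are hypothetical.

```c
#include <assert.h>
#include <stdint.h>

#define VL8 16  /* assumption: VLEN = 128, so 16 int8 elements per register */

/* Model of vwmul.vv at SEW = 8: 16 int8 x int8 -> 16 int16 products
 * (|product| <= 16384, so int16 never overflows). */
static void vwmul_e8(int16_t prod[VL8], const int8_t a[VL8], const int8_t b[VL8])
{
    for (int i = 0; i < VL8; i++)
        prod[i] = (int16_t)(a[i] * b[i]);
}

/* Model of vwredsum.vs at SEW = 16: returns vs1[0] + sum(vs2[*]) computed
 * in 32 bits, i.e. the new element 0 of vd. */
static int32_t vwredsum_e16(const int16_t vs2[VL8], int32_t vs1_0)
{
    int32_t s = vs1_0;
    for (int i = 0; i < VL8; i++)
        s += vs2[i];
    return s;
}

/* Approach 2: acc[x] stands for element 0 of register v<x>. */
static void approach2(const int8_t *in, const int8_t *f0, int inch_16,
                      int32_t acc[8])
{
    int16_t prod[VL8];

    assert(inch_16 >= 0);
    for (int x = 0; x < 8; x++)
        acc[x] = 0;                      /* vxor.vv at e64, m8 clears v0-v7 */
    for (int j = 0; j < inch_16; j++)
        for (int x = 0; x < 8; x++) {
            vwmul_e8(prod, in + 16 * j, f0 + 16 * 8 * j + 16 * x);
            acc[x] = vwredsum_e16(prod, acc[x]);   /* one reduction per column */
        }
}
```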

 

3. vwmul + vwredsum.vs (new)

                    "vsetvli        zero, zero, e16, m2\n\t"
                    "vxor.vv        v2, v2, v2\n\t"
                    "beqz           %4, 1f\n\t"

                    "0: \n\t"
                    "vsetvli        zero, zero, e8, m1\n\t"
                    "vle.v          v8, (%0)\n\t"
                    "addi           %0, %0, 16\n\t"
                    "vle.v          v9, (%1)\n\t"
                    "addi           %1, %1, 16\n\t"
                    "vle.v          v10, (%1)\n\t"
                    "addi           %1, %1, 16\n\t"
                    "vle.v          v11, (%1)\n\t"
                    "addi           %1, %1, 16\n\t"
                    "vle.v          v12, (%1)\n\t"
                    "addi           %1, %1, 16\n\t"
                    "vle.v          v13, (%1)\n\t"
                    "addi           %1, %1, 16\n\t"

                    "vwmul.vv       v14, v8, v9\n\t"
                    "vwmul.vv       v16, v8, v10\n\t"
                    "vle.v          v9, (%1)\n\t"
                    "addi           %1, %1, 16\n\t"
                    "vwmul.vv       v18, v8, v11\n\t"
                    "vle.v          v10, (%1)\n\t"
                    "addi           %1, %1, 16\n\t"
                    "vwmul.vv       v20, v8, v12\n\t"
                    "vle.v          v11, (%1)\n\t"
                    "addi           %1, %1, 16\n\t"

                    "vwmul.vv       v22, v8, v13\n\t"
                    "vwmul.vv       v24, v8, v9\n\t"
                    "vwmul.vv       v26, v8, v10\n\t"
                    "vwmul.vv       v28, v8, v11\n\t"

                    "vsetvli        zero, zero, e16, m2\n\t"
                    "vwredsum.vs    v2, v14, v2, 0\n\t"
                    "vwredsum.vs    v2, v16, v2, 1\n\t"
                    "vwredsum.vs    v2, v18, v2, 2\n\t"
                    "addi           %4, %4, -1\n\t"
                    "vwredsum.vs    v2, v20, v2, 3\n\t"
                    "vwredsum.vs    v3, v22, v3, 0\n\t"
                    "vwredsum.vs    v3, v24, v3, 1\n\t"
                    "vwredsum.vs    v3, v26, v3, 2\n\t"
                    "vwredsum.vs    v3, v28, v3, 3\n\t"

                    "bnez           %4, 0b\n\t"
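The proposed instruction can be modeled in C to show the register layout it enables. The semantics here are inferred from the usage above, not taken from a specification: vwredsum.vs vd, vs2, vs1, idx is assumed to write vs1[idx] + sum(vs2) into vd[idx] and to leave the other elements of vd unchanged. This matters because v2 serves as both vd and vs1 while holding four separate column sums. The model also assumes VLEN = 128, and all names are hypothetical. Under these assumptions, columns 0 to 3 end up in lanes 0 to 3 of v2 and columns 4 to 7 in v3. The 8 accumulators therefore fit in 2 registers instead of 8, already laid out contiguously.

```c
#include <assert.h>
#include <stdint.h>

#define VL8 16      /* assumption: VLEN = 128 */
#define LANES32 4   /* int32 lanes per register at VLEN = 128 */

/* Hypothetical model of the proposed indexed vwredsum.vs vd, vs2, vs1, idx:
 * vd[idx] = vs1[idx] + sum(vs2[*]); all other elements of vd are unchanged.
 * vd and vs1 may alias (vs1[idx] is read before vd[idx] is written). */
static void vwredsum_idx(int32_t vd[LANES32], const int16_t vs2[VL8],
                         const int32_t vs1[LANES32], int idx)
{
    int32_t s;

    assert(idx >= 0 && idx < LANES32);
    s = vs1[idx];
    for (int i = 0; i < VL8; i++)
        s += vs2[i];
    vd[idx] = s;
}

/* Approach 3: column x accumulates into lane x % 4 of v2 (x < 4) or v3. */
static void approach3(const int8_t *in, const int8_t *f0, int inch_16,
                      int32_t v2[LANES32], int32_t v3[LANES32])
{
    int16_t prod[VL8];

    for (int i = 0; i < LANES32; i++)
        v2[i] = v3[i] = 0;               /* vxor.vv at e16, m2 clears v2-v3 */
    for (int j = 0; j < inch_16; j++)
        for (int x = 0; x < 8; x++) {
            const int8_t *a = in + 16 * j;
            const int8_t *b = f0 + 16 * 8 * j + 16 * x;
            int32_t *vd = (x < 4) ? v2 : v3;
            for (int i = 0; i < VL8; i++)
                prod[i] = (int16_t)(a[i] * b[i]);   /* vwmul.vv */
            vwredsum_idx(vd, prod, vd, x % 4);
        }
}
```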

 

 

All three approaches are shown above. Any suggestions are welcome.

 

Yours

Damon

 

 

 

 

From: tech-vector-ext@... <tech-vector-ext@...> on behalf of David Horner
Sent: 2020-12-11 17:32
To: tech-vector-ext@...
Subject: Re: [RISC-V] [tech-vector-ext] The scenarios of GEMM for u/int8 data

 

 

On 2020-12-11 3:34 a.m., Linjie Yu wrote:

Hi all,

 

Recently, I have been optimizing the GEMM kernel for int8 data.

Can we see the git repository of your work?

I found that there is no good way to do it with the present vector ISA.

The main difficulty is that the accumulator is 32 bits, so the int8 data must be widened 4 times (vqmacc, or vwmul + vwmacc, or vwmul + vwadd), which leaves too few vector registers.
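The register pressure can be made concrete with a little arithmetic. The sketch below assumes VLEN = 128, which matches the 16-byte loads in the code in this thread, and one accumulator per output column, as in the 8-column kernel. It uses a hypothetical helper that counts the vector registers the accumulators alone occupy, rounding each accumulator up to whole registers.

```c
#include <assert.h>

/* Hypothetical helper: registers needed for `acc_count` accumulators, each
 * holding `elems` elements of `acc_bits` bits, on a machine with VLEN =
 * `vlen_bits`. Each accumulator is rounded up to whole registers. */
static int regs_for_accumulators(int vlen_bits, int elems, int acc_bits,
                                 int acc_count)
{
    assert(vlen_bits > 0 && elems > 0 && acc_bits > 0 && acc_count > 0);
    int bits = elems * acc_bits;
    int per_acc = (bits + vlen_bits - 1) / vlen_bits;
    return per_acc * acc_count;
}
```

Under these assumptions, 8 columns of 16 int32 accumulators need 4 registers each, or 32 in total. That is the whole register file before any input or weight register is counted. Keeping 16 int16 products per column, as the vwmul approach does, still takes 16 registers (v14 to v29 in that code). Packing one int32 per column into shared registers needs only 2.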

Does this mean the 32 vector registers are not enough,
or that the number of elements for the given input vector length is not enough?

 

I used two different approaches to optimize it with the present vector ISA:

1. vdot.vv + vredsum.vs (the tail processing is very complex)

2. vwmul + vwredsum.vs (vwredsum.vs used inside the loop)

Note vdot.vv is experimental. It is not planned for the v1.0 ratification proposal.

 

To solve this, I came up with a new instruction, called vwredsum.vs (new).

Unlike the old vwredsum.vs, which always puts the result in the first element, the new one can put the result at any element position given by an index. It can be used like this: vwredsum.vs v2, v1, v1, #2

With a "temporary working vector", this new instruction is a combination of the old one with any "insert scalar into element" instruction [such as vrgather.vv splat with mask].

 

But none of them is good enough. Does anyone have a better solution?

I would be happy to look at your current work to make suggestions if you could direct me to the code.

 

Yours

Damon
