correct implementation of AMX TDPBxxD instructions

This commit is contained in:
Shwartsman 2024-01-11 20:20:04 +02:00
parent b8054277cb
commit 8dca1e0e07

View File

@ -275,47 +275,189 @@ void BX_CPU_C::check_tiles(bxInstruction_c *i, unsigned tile_dst, unsigned tile_
}
}
#include "cpu/simd_vnni.h"
// Dot product of the four bytes packed in x with the four bytes packed in y,
// where the bytes of BOTH operands are interpreted as signed 8-bit values
// (the signed x signed flavour used by TDPBSSD).
//
// x, y    - packed source dwords, byte 0 in bits 7:0 ... byte 3 in bits 31:24
// returns - sum of the four byte products as a two's complement 32-bit value
BX_CPP_INLINE Bit32u DPBDSS(Bit32u x, Bit32u y)
{
  Bit32s sum = 0;

  for (unsigned shift = 0; shift < 32; shift += 8) {
    // (b ^ 0x80) - 0x80 sign-extends the low 8 bits; casting an unsigned
    // byte straight to Bit32s would zero-extend and lose the sign.
    const Bit32s xs = Bit32s(((x >> shift) & 0xff) ^ 0x80) - 0x80;
    const Bit32s ys = Bit32s(((y >> shift) & 0xff) ^ 0x80) - 0x80;
    sum += xs * ys; // |product| <= 128*128, so four of them cannot overflow
  }

  return Bit32u(sum);
}
// Dot product of the four bytes packed in x with the four bytes packed in y,
// where the bytes of x are interpreted as signed and the bytes of y as
// unsigned 8-bit values (the signed x unsigned flavour used by TDPBSUD).
//
// x, y    - packed source dwords, byte 0 in bits 7:0 ... byte 3 in bits 31:24
// returns - sum of the four byte products as a two's complement 32-bit value
BX_CPP_INLINE Bit32u DPBDSU(Bit32u x, Bit32u y)
{
  Bit32s sum = 0;

  for (unsigned shift = 0; shift < 32; shift += 8) {
    // (b ^ 0x80) - 0x80 sign-extends the low 8 bits; casting an unsigned
    // byte straight to Bit32s would zero-extend and lose the sign.
    const Bit32s xs = Bit32s(((x >> shift) & 0xff) ^ 0x80) - 0x80;
    const Bit32s yu = Bit32s((y >> shift) & 0xff); // zero-extended, 0..255
    sum += xs * yu; // |product| <= 128*255, so four of them cannot overflow
  }

  return Bit32u(sum);
}
// Dot product of the four bytes packed in x with the four bytes packed in y,
// where the bytes of x are interpreted as unsigned and the bytes of y as
// signed 8-bit values (the unsigned x signed flavour used by TDPBUSD).
//
// x, y    - packed source dwords, byte 0 in bits 7:0 ... byte 3 in bits 31:24
// returns - sum of the four byte products as a two's complement 32-bit value
BX_CPP_INLINE Bit32u DPBDUS(Bit32u x, Bit32u y)
{
  Bit32s sum = 0;

  for (unsigned shift = 0; shift < 32; shift += 8) {
    const Bit32s xu = Bit32s((x >> shift) & 0xff); // zero-extended, 0..255
    // (b ^ 0x80) - 0x80 sign-extends the low 8 bits; casting an unsigned
    // byte straight to Bit32s would zero-extend and lose the sign.
    const Bit32s ys = Bit32s(((y >> shift) & 0xff) ^ 0x80) - 0x80;
    sum += xu * ys; // |product| <= 255*128, so four of them cannot overflow
  }

  return Bit32u(sum);
}
// Dot product of the four bytes packed in x with the four bytes packed in y,
// with the bytes of both operands treated as unsigned 8-bit values.
//
// x, y    - packed source dwords, byte 0 in bits 7:0 ... byte 3 in bits 31:24
// returns - sum of the four byte products
BX_CPP_INLINE Bit32u DPBDUU(Bit32u x, Bit32u y)
{
  Bit32u acc = 0;

  // Walk the dwords one byte lane at a time, lowest lane first.
  for (unsigned shift = 0; shift < 32; shift += 8) {
    const Bit32u xlane = (x >> shift) & 0xff;
    const Bit32u ylane = (y >> shift) & 0xff;
    acc += xlane * ylane;
  }

  return acc;
}
// TDPBSSD: for every dword element of the destination tile, accumulate the
// byte dot products DPBDSS(src1[m][k], src2[k][n]) over k (both sources
// treated as signed bytes by DPBDSS).
//
// i - decoded instruction; supplies the destination and two source tile indices
void BX_CPP_AttrRegparmN(1) BX_CPU_C::TDPBSSD_TnnnTrmTreg(bxInstruction_c *i)
{
  unsigned tile_dst = i->dst(), tile_src1 = i->src1(), tile_src2 = i->src2();
  // operand validation is delegated to check_tiles
  check_tiles(i, tile_dst, tile_src1, tile_src2);

  /*     R   C            */
  /* A = m x k (tsrc1)    */
  /* B = k x n (tsrc2)    */
  /* C = m x n (tsrcdest) */
  unsigned max_n = BX_CPU_THIS_PTR amx->tile_bytes_per_row(tile_dst) / 4; // dwords per destination row
  unsigned max_m = BX_CPU_THIS_PTR amx->tile_num_rows(tile_dst);
  unsigned max_k = BX_CPU_THIS_PTR amx->tile_num_rows(tile_src2);

  AMX::TILE *tdst  = &(BX_CPU_THIS_PTR amx->tile[tile_dst]);
  AMX::TILE *tsrc1 = &(BX_CPU_THIS_PTR amx->tile[tile_src1]);
  AMX::TILE *tsrc2 = &(BX_CPU_THIS_PTR amx->tile[tile_src2]);

  for (unsigned m=0; m < max_m; m++) {
    BxPackedAvxRegister* tmp = &(tdst->row[m]);
    for (unsigned k=0; k < max_k; k++) {
      for (unsigned n=0; n < max_n; n++) {
        // dword k of A's row m against dword n of B's row k
        tmp->vmm32s(n) += DPBDSS(tsrc1->row[m].vmm32u(k), tsrc2->row[k].vmm32u(n));
      }
    }
    // presumably clears the dwords of row m from max_n upwards - confirm in AMX::TILE
    tdst->zero_upper_row_data32(m, max_n);
  }

  BX_CPU_THIS_PTR amx->set_tile_used(tile_dst);
  // presumably clears the rows from max_m upwards - confirm in AMX::TILE
  BX_CPU_THIS_PTR amx->tile[tile_dst].clear_upper_rows(max_m);
  BX_CPU_THIS_PTR amx->restart();

  BX_NEXT_INSTR(i);
}
// TDPBSUD: for every dword element of the destination tile, accumulate the
// byte dot products DPBDSU(src1[m][k], src2[k][n]) over k.
//
// i - decoded instruction; supplies the destination and two source tile indices
void BX_CPP_AttrRegparmN(1) BX_CPU_C::TDPBSUD_TnnnTrmTreg(bxInstruction_c *i)
{
  unsigned dst_idx = i->dst();
  unsigned src1_idx = i->src1();
  unsigned src2_idx = i->src2();
  // operand validation is delegated to check_tiles
  check_tiles(i, dst_idx, src1_idx, src2_idx);

  // Matrix view (rows x columns, in dwords):
  //   A = m x k (tsrc1), B = k x n (tsrc2), C = m x n (tsrcdest)
  unsigned dst_dwords = BX_CPU_THIS_PTR amx->tile_bytes_per_row(dst_idx) / 4;
  unsigned dst_rows   = BX_CPU_THIS_PTR amx->tile_num_rows(dst_idx);
  unsigned depth      = BX_CPU_THIS_PTR amx->tile_num_rows(src2_idx);

  AMX::TILE &acc_tile = BX_CPU_THIS_PTR amx->tile[dst_idx];
  AMX::TILE &a_tile   = BX_CPU_THIS_PTR amx->tile[src1_idx];
  AMX::TILE &b_tile   = BX_CPU_THIS_PTR amx->tile[src2_idx];

  for (unsigned r = 0; r < dst_rows; r++) {
    BxPackedAvxRegister &acc_row = acc_tile.row[r];

    for (unsigned kk = 0; kk < depth; kk++) {
      for (unsigned c = 0; c < dst_dwords; c++) {
        // dword kk of A's row r against dword c of B's row kk
        acc_row.vmm32s(c) += DPBDSU(a_tile.row[r].vmm32u(kk), b_tile.row[kk].vmm32u(c));
      }
    }

    // presumably clears the dwords of row r from dst_dwords upwards - confirm in AMX::TILE
    acc_tile.zero_upper_row_data32(r, dst_dwords);
  }

  BX_CPU_THIS_PTR amx->set_tile_used(dst_idx);
  // presumably clears the rows from dst_rows upwards - confirm in AMX::TILE
  acc_tile.clear_upper_rows(dst_rows);
  BX_CPU_THIS_PTR amx->restart();

  BX_NEXT_INSTR(i);
}
// TDPBUSD: for every dword element of the destination tile, accumulate the
// byte dot products DPBDUS(src1[m][k], src2[k][n]) over k.
//
// i - decoded instruction; supplies the destination and two source tile indices
void BX_CPP_AttrRegparmN(1) BX_CPU_C::TDPBUSD_TnnnTrmTreg(bxInstruction_c *i)
{
  unsigned dst_idx = i->dst();
  unsigned src1_idx = i->src1();
  unsigned src2_idx = i->src2();
  // operand validation is delegated to check_tiles
  check_tiles(i, dst_idx, src1_idx, src2_idx);

  // Matrix view (rows x columns, in dwords):
  //   A = m x k (tsrc1), B = k x n (tsrc2), C = m x n (tsrcdest)
  unsigned dst_dwords = BX_CPU_THIS_PTR amx->tile_bytes_per_row(dst_idx) / 4;
  unsigned dst_rows   = BX_CPU_THIS_PTR amx->tile_num_rows(dst_idx);
  unsigned depth      = BX_CPU_THIS_PTR amx->tile_num_rows(src2_idx);

  AMX::TILE &acc_tile = BX_CPU_THIS_PTR amx->tile[dst_idx];
  AMX::TILE &a_tile   = BX_CPU_THIS_PTR amx->tile[src1_idx];
  AMX::TILE &b_tile   = BX_CPU_THIS_PTR amx->tile[src2_idx];

  for (unsigned r = 0; r < dst_rows; r++) {
    BxPackedAvxRegister &acc_row = acc_tile.row[r];

    for (unsigned kk = 0; kk < depth; kk++) {
      for (unsigned c = 0; c < dst_dwords; c++) {
        // dword kk of A's row r against dword c of B's row kk
        acc_row.vmm32s(c) += DPBDUS(a_tile.row[r].vmm32u(kk), b_tile.row[kk].vmm32u(c));
      }
    }

    // presumably clears the dwords of row r from dst_dwords upwards - confirm in AMX::TILE
    acc_tile.zero_upper_row_data32(r, dst_dwords);
  }

  BX_CPU_THIS_PTR amx->set_tile_used(dst_idx);
  // presumably clears the rows from dst_rows upwards - confirm in AMX::TILE
  acc_tile.clear_upper_rows(dst_rows);
  BX_CPU_THIS_PTR amx->restart();

  BX_NEXT_INSTR(i);
}
// TDPBUUD: for every dword element of the destination tile, accumulate the
// byte dot products DPBDUU(src1[m][k], src2[k][n]) over k. Unlike the other
// three variants the accumulator is accessed as an unsigned dword.
//
// i - decoded instruction; supplies the destination and two source tile indices
void BX_CPP_AttrRegparmN(1) BX_CPU_C::TDPBUUD_TnnnTrmTreg(bxInstruction_c *i)
{
  unsigned dst_idx = i->dst();
  unsigned src1_idx = i->src1();
  unsigned src2_idx = i->src2();
  // operand validation is delegated to check_tiles
  check_tiles(i, dst_idx, src1_idx, src2_idx);

  // Matrix view (rows x columns, in dwords):
  //   A = m x k (tsrc1), B = k x n (tsrc2), C = m x n (tsrcdest)
  unsigned dst_dwords = BX_CPU_THIS_PTR amx->tile_bytes_per_row(dst_idx) / 4;
  unsigned dst_rows   = BX_CPU_THIS_PTR amx->tile_num_rows(dst_idx);
  unsigned depth      = BX_CPU_THIS_PTR amx->tile_num_rows(src2_idx);

  AMX::TILE &acc_tile = BX_CPU_THIS_PTR amx->tile[dst_idx];
  AMX::TILE &a_tile   = BX_CPU_THIS_PTR amx->tile[src1_idx];
  AMX::TILE &b_tile   = BX_CPU_THIS_PTR amx->tile[src2_idx];

  for (unsigned r = 0; r < dst_rows; r++) {
    BxPackedAvxRegister &acc_row = acc_tile.row[r];

    for (unsigned kk = 0; kk < depth; kk++) {
      for (unsigned c = 0; c < dst_dwords; c++) {
        // dword kk of A's row r against dword c of B's row kk
        acc_row.vmm32u(c) += DPBDUU(a_tile.row[r].vmm32u(kk), b_tile.row[kk].vmm32u(c));
      }
    }

    // presumably clears the dwords of row r from dst_dwords upwards - confirm in AMX::TILE
    acc_tile.zero_upper_row_data32(r, dst_dwords);
  }

  BX_CPU_THIS_PTR amx->set_tile_used(dst_idx);
  // presumably clears the rows from dst_rows upwards - confirm in AMX::TILE
  acc_tile.clear_upper_rows(dst_rows);
  BX_CPU_THIS_PTR amx->restart();

  BX_NEXT_INSTR(i);
}
#include "bf16.h"