correct implementation of AMX TDPBxxD instructions
This commit is contained in:
parent
b8054277cb
commit
8dca1e0e07
@ -275,47 +275,189 @@ void BX_CPU_C::check_tiles(bxInstruction_c *i, unsigned tile_dst, unsigned tile_
|
||||
}
|
||||
}
|
||||
|
||||
#include "cpu/simd_vnni.h"
|
||||
// Dot product of the four bytes packed in x with the four bytes packed in y,
// with the bytes of BOTH operands interpreted as signed 8-bit values.
// Returns the 32-bit sum of the four byte products as a raw dword.
BX_CPP_INLINE Bit32u DPBDSS(Bit32u x, Bit32u y)
{
  Bit32s sum = 0;

  for (unsigned n = 0; n < 4; n++) {
    // Widening an unsigned byte would zero-extend it, so sign-extend
    // explicitly: (b ^ 0x80) - 0x80 maps 0x00..0x7f to 0..127 and
    // 0x80..0xff to -128..-1.
    const Bit32s xbyte = Bit32s(((x >> (8*n)) & 0xff) ^ 0x80) - 0x80;
    const Bit32s ybyte = Bit32s(((y >> (8*n)) & 0xff) ^ 0x80) - 0x80;

    // each product is within [-16256, 16384], so the sum cannot overflow
    sum += xbyte * ybyte;
  }

  return Bit32u(sum);
}
|
||||
|
||||
// Dot product of the four bytes packed in x with the four bytes packed in y,
// with the bytes of x interpreted as SIGNED and the bytes of y as UNSIGNED
// 8-bit values. Returns the 32-bit sum of the four products as a raw dword.
BX_CPP_INLINE Bit32u DPBDSU(Bit32u x, Bit32u y)
{
  Bit32s sum = 0;

  for (unsigned n = 0; n < 4; n++) {
    // x byte: sign-extend explicitly, (b ^ 0x80) - 0x80 maps 0x80..0xff
    // to -128..-1 (a plain widening of an unsigned byte would zero-extend).
    const Bit32s xbyte = Bit32s(((x >> (8*n)) & 0xff) ^ 0x80) - 0x80;
    // y byte: zero-extended, 0..255 always fits in a signed dword
    const Bit32s ybyte = Bit32s((y >> (8*n)) & 0xff);

    // each product is within [-32640, 32385], so the sum cannot overflow
    sum += xbyte * ybyte;
  }

  return Bit32u(sum);
}
|
||||
|
||||
// Dot product of the four bytes packed in x with the four bytes packed in y,
// with the bytes of x interpreted as UNSIGNED and the bytes of y as SIGNED
// 8-bit values. Returns the 32-bit sum of the four products as a raw dword.
BX_CPP_INLINE Bit32u DPBDUS(Bit32u x, Bit32u y)
{
  Bit32s sum = 0;

  for (unsigned n = 0; n < 4; n++) {
    // x byte: zero-extended, 0..255 always fits in a signed dword
    const Bit32s xbyte = Bit32s((x >> (8*n)) & 0xff);
    // y byte: sign-extend explicitly, (b ^ 0x80) - 0x80 maps 0x80..0xff
    // to -128..-1 (a plain widening of an unsigned byte would zero-extend).
    const Bit32s ybyte = Bit32s(((y >> (8*n)) & 0xff) ^ 0x80) - 0x80;

    // each product is within [-32640, 32385], so the sum cannot overflow
    sum += xbyte * ybyte;
  }

  return Bit32u(sum);
}
|
||||
|
||||
// Dot product of the four bytes packed in x with the four bytes packed in y,
// with the bytes of both operands interpreted as unsigned 8-bit values.
// Returns the 32-bit sum of the four byte products.
BX_CPP_INLINE Bit32u DPBDUU(Bit32u x, Bit32u y)
{
  Bit32u sum = 0;

  // walk the dwords one byte lane at a time, lowest byte first
  for (unsigned shift = 0; shift < 32; shift += 8) {
    const Bit32u xlane = (x >> shift) & 0xff;
    const Bit32u ylane = (y >> shift) & 0xff;
    sum += xlane * ylane;
  }

  return sum;
}
|
||||
|
||||
// TDPBSSD handler: for every dword element of the destination tile,
// accumulate the DPBDSS byte dot products of the matching row of tsrc1
// and column of tsrc2 (dword granularity) on top of the existing value.
void BX_CPP_AttrRegparmN(1) BX_CPU_C::TDPBSSD_TnnnTrmTreg(bxInstruction_c *i)
{
  unsigned tile_dst = i->dst(), tile_src1 = i->src1(), tile_src2 = i->src2();
  // validate the tile operands before any tile data is touched
  check_tiles(i, tile_dst, tile_src1, tile_src2);

  /*       R   C             */
  /*   A = m x k (tsrc1)     */
  /*   B = k x n (tsrc2)     */
  /*   C = m x n (tsrcdest)  */
  unsigned max_n = BX_CPU_THIS_PTR amx->tile_bytes_per_row(tile_dst) / 4; // dwords per destination row
  unsigned max_m = BX_CPU_THIS_PTR amx->tile_num_rows(tile_dst);
  unsigned max_k = BX_CPU_THIS_PTR amx->tile_num_rows(tile_src2);

  AMX::TILE *tdst  = &(BX_CPU_THIS_PTR amx->tile[tile_dst]);
  AMX::TILE *tsrc1 = &(BX_CPU_THIS_PTR amx->tile[tile_src1]);
  AMX::TILE *tsrc2 = &(BX_CPU_THIS_PTR amx->tile[tile_src2]);

  for (unsigned m=0; m < max_m; m++) {
    // accumulate in place: tmp aliases destination row m
    BxPackedAvxRegister* tmp = &(tdst->row[m]);
    for (unsigned k=0; k < max_k; k++) {
      for (unsigned n=0; n < max_n; n++) {
        // C[m][n] += dot4(A[m][k], B[k][n])
        tmp->vmm32s(n) += DPBDSS(tsrc1->row[m].vmm32u(k), tsrc2->row[k].vmm32u(n));
      }
    }
    // clear the part of row m beyond the max_n dwords just written
    tdst->zero_upper_row_data32(m, max_n);
  }

  BX_CPU_THIS_PTR amx->set_tile_used(tile_dst);
  // clear destination rows from max_m upwards
  BX_CPU_THIS_PTR amx->tile[tile_dst].clear_upper_rows(max_m);
  BX_CPU_THIS_PTR amx->restart();

  BX_NEXT_INSTR(i);
}
|
||||
|
||||
// TDPBSUD handler: for every dword element of the destination tile,
// accumulate the DPBDSU byte dot products of the matching row of tsrc1
// and column of tsrc2 (dword granularity) on top of the existing value.
void BX_CPP_AttrRegparmN(1) BX_CPU_C::TDPBSUD_TnnnTrmTreg(bxInstruction_c *i)
{
  unsigned dst_idx = i->dst();
  unsigned src1_idx = i->src1();
  unsigned src2_idx = i->src2();
  // validate the tile operands before any tile data is touched
  check_tiles(i, dst_idx, src1_idx, src2_idx);

  /* Matrix view of the operation (rows x columns, in dwords): */
  /*   A = m x k (tsrc1)     */
  /*   B = k x n (tsrc2)     */
  /*   C = m x n (tsrcdest)  */
  unsigned dst_dwords = BX_CPU_THIS_PTR amx->tile_bytes_per_row(dst_idx) / 4;
  unsigned dst_rows = BX_CPU_THIS_PTR amx->tile_num_rows(dst_idx);
  unsigned src2_rows = BX_CPU_THIS_PTR amx->tile_num_rows(src2_idx);

  AMX::TILE &dst_tile = BX_CPU_THIS_PTR amx->tile[dst_idx];
  AMX::TILE &src1_tile = BX_CPU_THIS_PTR amx->tile[src1_idx];
  AMX::TILE &src2_tile = BX_CPU_THIS_PTR amx->tile[src2_idx];

  for (unsigned row = 0; row < dst_rows; row++) {
    // results are accumulated directly into destination row 'row'
    BxPackedAvxRegister &acc = dst_tile.row[row];

    for (unsigned depth = 0; depth < src2_rows; depth++) {
      BxPackedAvxRegister &b_row = src2_tile.row[depth];

      for (unsigned col = 0; col < dst_dwords; col++) {
        // C[row][col] += dot4(A[row][depth], B[depth][col])
        acc.vmm32s(col) += DPBDSU(src1_tile.row[row].vmm32u(depth), b_row.vmm32u(col));
      }
    }

    // clear the part of this row beyond the dst_dwords dwords just written
    dst_tile.zero_upper_row_data32(row, dst_dwords);
  }

  BX_CPU_THIS_PTR amx->set_tile_used(dst_idx);
  // clear destination rows from dst_rows upwards
  dst_tile.clear_upper_rows(dst_rows);
  BX_CPU_THIS_PTR amx->restart();

  BX_NEXT_INSTR(i);
}
|
||||
|
||||
// TDPBUSD handler: for every dword element of the destination tile,
// accumulate the DPBDUS byte dot products of the matching row of tsrc1
// and column of tsrc2 (dword granularity) on top of the existing value.
void BX_CPP_AttrRegparmN(1) BX_CPU_C::TDPBUSD_TnnnTrmTreg(bxInstruction_c *i)
{
  unsigned dst_idx = i->dst();
  unsigned src1_idx = i->src1();
  unsigned src2_idx = i->src2();
  // validate the tile operands before any tile data is touched
  check_tiles(i, dst_idx, src1_idx, src2_idx);

  /* Matrix view of the operation (rows x columns, in dwords): */
  /*   A = m x k (tsrc1)     */
  /*   B = k x n (tsrc2)     */
  /*   C = m x n (tsrcdest)  */
  unsigned dst_dwords = BX_CPU_THIS_PTR amx->tile_bytes_per_row(dst_idx) / 4;
  unsigned dst_rows = BX_CPU_THIS_PTR amx->tile_num_rows(dst_idx);
  unsigned src2_rows = BX_CPU_THIS_PTR amx->tile_num_rows(src2_idx);

  AMX::TILE &dst_tile = BX_CPU_THIS_PTR amx->tile[dst_idx];
  AMX::TILE &src1_tile = BX_CPU_THIS_PTR amx->tile[src1_idx];
  AMX::TILE &src2_tile = BX_CPU_THIS_PTR amx->tile[src2_idx];

  for (unsigned row = 0; row < dst_rows; row++) {
    // results are accumulated directly into destination row 'row'
    BxPackedAvxRegister &acc = dst_tile.row[row];

    for (unsigned depth = 0; depth < src2_rows; depth++) {
      BxPackedAvxRegister &b_row = src2_tile.row[depth];

      for (unsigned col = 0; col < dst_dwords; col++) {
        // C[row][col] += dot4(A[row][depth], B[depth][col])
        acc.vmm32s(col) += DPBDUS(src1_tile.row[row].vmm32u(depth), b_row.vmm32u(col));
      }
    }

    // clear the part of this row beyond the dst_dwords dwords just written
    dst_tile.zero_upper_row_data32(row, dst_dwords);
  }

  BX_CPU_THIS_PTR amx->set_tile_used(dst_idx);
  // clear destination rows from dst_rows upwards
  dst_tile.clear_upper_rows(dst_rows);
  BX_CPU_THIS_PTR amx->restart();

  BX_NEXT_INSTR(i);
}
|
||||
|
||||
// TDPBUUD handler: for every dword element of the destination tile,
// accumulate the DPBDUU byte dot products of the matching row of tsrc1
// and column of tsrc2 (dword granularity) on top of the existing value.
void BX_CPP_AttrRegparmN(1) BX_CPU_C::TDPBUUD_TnnnTrmTreg(bxInstruction_c *i)
{
  unsigned dst_idx = i->dst();
  unsigned src1_idx = i->src1();
  unsigned src2_idx = i->src2();
  // validate the tile operands before any tile data is touched
  check_tiles(i, dst_idx, src1_idx, src2_idx);

  /* Matrix view of the operation (rows x columns, in dwords): */
  /*   A = m x k (tsrc1)     */
  /*   B = k x n (tsrc2)     */
  /*   C = m x n (tsrcdest)  */
  unsigned dst_dwords = BX_CPU_THIS_PTR amx->tile_bytes_per_row(dst_idx) / 4;
  unsigned dst_rows = BX_CPU_THIS_PTR amx->tile_num_rows(dst_idx);
  unsigned src2_rows = BX_CPU_THIS_PTR amx->tile_num_rows(src2_idx);

  AMX::TILE &dst_tile = BX_CPU_THIS_PTR amx->tile[dst_idx];
  AMX::TILE &src1_tile = BX_CPU_THIS_PTR amx->tile[src1_idx];
  AMX::TILE &src2_tile = BX_CPU_THIS_PTR amx->tile[src2_idx];

  for (unsigned row = 0; row < dst_rows; row++) {
    // results are accumulated directly into destination row 'row'
    BxPackedAvxRegister &acc = dst_tile.row[row];

    for (unsigned depth = 0; depth < src2_rows; depth++) {
      BxPackedAvxRegister &b_row = src2_tile.row[depth];

      for (unsigned col = 0; col < dst_dwords; col++) {
        // unsigned accumulate: C[row][col] += dot4(A[row][depth], B[depth][col])
        acc.vmm32u(col) += DPBDUU(src1_tile.row[row].vmm32u(depth), b_row.vmm32u(col));
      }
    }

    // clear the part of this row beyond the dst_dwords dwords just written
    dst_tile.zero_upper_row_data32(row, dst_dwords);
  }

  BX_CPU_THIS_PTR amx->set_tile_used(dst_idx);
  // clear destination rows from dst_rows upwards
  dst_tile.clear_upper_rows(dst_rows);
  BX_CPU_THIS_PTR amx->restart();

  BX_NEXT_INSTR(i);
}
|
||||
|
||||
#include "bf16.h"
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user