fixed TILELOADD/TILESTORED and TDPBF16PS

This commit is contained in:
Shwartsman 2024-01-11 21:04:14 +02:00
parent 8dca1e0e07
commit 9e36971e0f

View File

@ -122,7 +122,7 @@ void BX_CPP_AttrRegparmN(1) BX_CPU_C::TILELOADD_TnnnMdq(bxInstruction_c *i)
}
unsigned elements_per_row = bytes_per_row / 4;
Bit32u mask = (elements_per_row < 16) ? (BX_CONST64(1) << elements_per_row) - 1 : BX_CONST64(0xFFFF);
Bit32u mask = (elements_per_row < 16) ? ((1 << elements_per_row) - 1) : 0xFFFF;
BX_CPU_THIS_PTR amx->set_tile_used(tile);
@ -130,6 +130,7 @@ void BX_CPP_AttrRegparmN(1) BX_CPU_C::TILELOADD_TnnnMdq(bxInstruction_c *i)
Bit64u start_eaddr = BX_READ_64BIT_REG(i->sibBase()) + (Bit64s) i->displ32s();
Bit64u stride = BX_READ_64BIT_REG(i->sibIndex()) << i->sibScale();
i->setVL(BX_VL512);
for (unsigned row=BX_CPU_THIS_PTR amx->start_row; row < rows; row++) {
BxPackedAvxRegister *data = &(BX_CPU_THIS_PTR amx->tile[tile].row[row]);
@ -181,14 +182,14 @@ void BX_CPP_AttrRegparmN(1) BX_CPU_C::TILESTORED_MdqTnnn(bxInstruction_c *i)
}
unsigned elements_per_row = bytes_per_row / 4;
Bit32u mask = (elements_per_row < 16) ? (BX_CONST64(1) << elements_per_row) - 1 : BX_CONST64(0xFFFF);
Bit32u mask = (elements_per_row < 16) ? ((1 << elements_per_row) - 1) : 0xFFFF;
i->setVL(BX_VL512);
Bit64u start_eaddr = BX_READ_64BIT_REG(i->sibBase()) + (Bit64s) i->displ32s();
Bit64u stride = BX_READ_64BIT_REG(i->sibIndex()) << i->sibScale();
for (unsigned row=BX_CPU_THIS_PTR amx->start_row; row < rows; row++) {
BxPackedAvxRegister *data = &(BX_CPU_THIS_PTR amx->tile[tile].row[row]);
Bit64u eaddr = start_eaddr + row * stride;
if (bytes_per_row == 64)
write_linear_zmmword(i->seg(), get_laddr64(i->seg(), eaddr), data);
@ -484,15 +485,15 @@ void BX_CPP_AttrRegparmN(1) BX_CPU_C::TDPBF16PS_TnnnTrmTreg(bxInstruction_c *i)
for (unsigned m=0; m < max_m; m++) {
float32 tmp[32]; // new empty array
for (unsigned n=0; n < 32; n++) tmp[32] = 0;
for (unsigned n=0; n < 32; n++) tmp[n] = 0;
for (unsigned k=0; k < max_k; k++) {
for (unsigned n=0; n < max_n; n++) {
tmp[2*n] = float32_fmadd(convert_bfloat16_to_fp32(tsrc1->row[m].vmm16u(2*k)),
convert_bfloat16_to_fp32(tsrc2->row[k].vmm16u(2*n)), tmp[2*n], status);
convert_bfloat16_to_fp32(tsrc2->row[k].vmm16u(2*n)), tmp[2*n], status);
tmp[2*n+1] = float32_fmadd(convert_bfloat16_to_fp32(tsrc1->row[m].vmm16u(2*k+1)),
convert_bfloat16_to_fp32(tsrc2->row[k].vmm16u(2*n+1)), tmp[2*n+1], status);
convert_bfloat16_to_fp32(tsrc2->row[k].vmm16u(2*n+1)), tmp[2*n+1], status);
}
}