diff --git a/src/common/aarch64/oaknut_abi.h b/src/common/aarch64/oaknut_abi.h index 7323cfca49..025998bb5c 100644 --- a/src/common/aarch64/oaknut_abi.h +++ b/src/common/aarch64/oaknut_abi.h @@ -89,22 +89,45 @@ inline void ABI_PushRegisters(oaknut::CodeGenerator& code, std::bitset<64> regs, code.SUB(SP, SP, frame_info.subtraction); } - // TODO(wunk): Push pairs of registers at a time with STP - std::size_t offset = 0; - for (std::size_t i = 0; i < 32; ++i) { - if (regs[i] && ABI_ALL_GPRS[i]) { - const XReg reg = IndexToXReg(i); - code.STR(reg, SP, offset); - offset += 8; + { + const std::bitset<64> gprs_mask = (regs & ABI_ALL_GPRS); + std::vector gprs; + gprs.reserve(32); + for (u8 i = 0; i < 32; ++i) { + if (gprs_mask.test(i)) { + gprs.emplace_back(IndexToXReg(i)); + } + } + + if (!gprs.empty()) { + for (size_t i = 0; i < gprs.size() - 1; i += 2) { + code.STP(gprs[i], gprs[i + 1], SP, i * sizeof(u64)); + } + if (gprs.size() % 2 == 1) { + const size_t i = gprs.size() - 1; + code.STR(gprs[i], SP, i * sizeof(u64)); + } } } - offset = 0; - for (std::size_t i = 32; i < 64; ++i) { - if (regs[i] && ABI_ALL_FPRS[i]) { - const VReg reg = IndexToVReg(i); - code.STR(reg.toQ(), SP, u16(frame_info.fprs_offset + offset)); - offset += 16; + { + const std::bitset<64> fprs_mask = (regs & ABI_ALL_FPRS); + std::vector fprs; + fprs.reserve(32); + for (u8 i = 32; i < 64; ++i) { + if (fprs_mask.test(i)) { + fprs.emplace_back(IndexToVReg(i).toQ()); + } + } + + if (!fprs.empty()) { + for (size_t i = 0; i < fprs.size() - 1; i += 2) { + code.STP(fprs[i], fprs[i + 1], SP, frame_info.fprs_offset + i * (sizeof(u64) * 2)); + } + if (fprs.size() % 2 == 1) { + const size_t i = fprs.size() - 1; + code.STR(fprs[i], SP, frame_info.fprs_offset + i * (sizeof(u64) * 2)); + } } } @@ -125,22 +148,45 @@ inline void ABI_PopRegisters(oaknut::CodeGenerator& code, std::bitset<64> regs, code.ADD(SP, SP, frame_size); } - // TODO(wunk): Pop pairs of registers at a time with LDP - std::size_t offset = 0; - for (std::size_t i = 0; i < 32; ++i) { - if (regs[i] && ABI_ALL_GPRS[i]) { - const XReg reg = IndexToXReg(i); - code.LDR(reg, SP, offset); - offset += 8; + { + const std::bitset<64> gprs_mask = (regs & ABI_ALL_GPRS); + std::vector gprs; + gprs.reserve(32); + for (u8 i = 0; i < 32; ++i) { + if (gprs_mask.test(i)) { + gprs.emplace_back(IndexToXReg(i)); + } + } + + if (!gprs.empty()) { + for (size_t i = 0; i < gprs.size() - 1; i += 2) { + code.LDP(gprs[i], gprs[i + 1], SP, i * sizeof(u64)); + } + if (gprs.size() % 2 == 1) { + const size_t i = gprs.size() - 1; + code.LDR(gprs[i], SP, i * sizeof(u64)); + } } } - offset = 0; - for (std::size_t i = 32; i < 64; ++i) { - if (regs[i] && ABI_ALL_FPRS[i]) { - const VReg reg = IndexToVReg(i); - code.LDR(reg.toQ(), SP, frame_info.fprs_offset + offset); - offset += 16; + { + const std::bitset<64> fprs_mask = (regs & ABI_ALL_FPRS); + std::vector fprs; + fprs.reserve(32); + for (u8 i = 32; i < 64; ++i) { + if (fprs_mask.test(i)) { + fprs.emplace_back(IndexToVReg(i).toQ()); + } + } + + if (!fprs.empty()) { + for (size_t i = 0; i < fprs.size() - 1; i += 2) { + code.LDP(fprs[i], fprs[i + 1], SP, frame_info.fprs_offset + i * (sizeof(u64) * 2)); + } + if (fprs.size() % 2 == 1) { + const size_t i = fprs.size() - 1; + code.LDR(fprs[i], SP, frame_info.fprs_offset + i * (sizeof(u64) * 2)); + } } }