diff --git a/Python/jit.c b/Python/jit.c index 33320761621c4c5..d0fe7cd98a417a6 100644 --- a/Python/jit.c +++ b/Python/jit.c @@ -3,6 +3,7 @@ #include "Python.h" #include "pycore_abstract.h" +#include "pycore_bitutils.h" #include "pycore_call.h" #include "pycore_ceval.h" #include "pycore_critical_section.h" @@ -390,8 +391,70 @@ patch_x86_64_32rx(unsigned char *location, uint64_t value) patch_32r(location, value); } +void patch_aarch64_trampoline(unsigned char *location, int ordinal); + #include "jit_stencils.h" +typedef struct { + void *mem; + SymbolMask mask; + size_t size; +} TrampolineState; + +//TODO: remove as global variable +TrampolineState trampoline_state; + +#if defined(__aarch64__) || defined(_M_ARM64) + #define TRAMPOLINE_SIZE 16 +#else + #define TRAMPOLINE_SIZE 0 +#endif + +// Generate and patch AArch64 trampolines. The symbols to jump to are stored +// in the jit_stencils.h in the symbols_map. +void +patch_aarch64_trampoline(unsigned char *location, int ordinal) +{ + // Masking is done modulo 32 as the mask is stored as an array of uint32_t + const uint32_t symbol_mask = 1 << (ordinal % 32); + const uint32_t trampoline_mask = trampoline_state.mask[ordinal / 32]; + assert(symbol_mask & trampoline_mask); + + // Count the number of set bits in the trampoline mask lower than ordinal, + // this gives the index into the array of trampolines. + int index = _Py_popcount32(trampoline_mask & (symbol_mask - 1)); + for (int i = 0; i < ordinal / 32; i++) { + index += _Py_popcount32(trampoline_state.mask[i]); + } + + uint32_t *p = trampoline_state.mem + index * TRAMPOLINE_SIZE; + assert((size_t)index * TRAMPOLINE_SIZE < trampoline_state.size); + + uintptr_t value = (uintptr_t)symbols_map[ordinal]; + + /* Generate the trampoline + 0: 58000048 ldr x8, 8 + 4: d61f0100 br x8 + 8: 00000000 // The next two words contain the 64-bit address to jump to. + c: 00000000 + */ + p[0] = 0x58000048; + p[1] = 0xD61F0100; + p[2] = value & 0xffffffff; + p[3] = value >> 32; + + patch_aarch64_26r(location, (uintptr_t)p); +} + +static void +combine_symbol_mask(const SymbolMask src, SymbolMask dest, size_t size) +{ + // Calculate the union of the trampolines required by each StencilGroup + for (size_t i = 0; i < size; i++) { + dest[i] |= src[i]; + } +} + // Compiles executor in-place. Don't forget to call _PyJIT_Free later! int _PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction trace[], size_t length) @@ -401,6 +464,7 @@ _PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction trace[], siz uintptr_t instruction_starts[UOP_MAX_TRACE_LENGTH]; size_t code_size = 0; size_t data_size = 0; + trampoline_state = (TrampolineState){}; group = &trampoline; code_size += group->code_size; data_size += group->data_size; @@ -410,15 +474,25 @@ _PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction trace[], siz instruction_starts[i] = code_size; code_size += group->code_size; data_size += group->data_size; + combine_symbol_mask(group->trampoline_mask, + trampoline_state.mask, + Py_ARRAY_LENGTH(trampoline_state.mask)); } group = &stencil_groups[_FATAL_ERROR]; code_size += group->code_size; data_size += group->data_size; + combine_symbol_mask(group->trampoline_mask, + trampoline_state.mask, + Py_ARRAY_LENGTH(trampoline_state.mask)); + // Calculate the size of the trampolines required by the whole trace + for (size_t i = 0; i < Py_ARRAY_LENGTH(trampoline_state.mask); i++) { + trampoline_state.size += _Py_popcount32(trampoline_state.mask[i]) * TRAMPOLINE_SIZE; + } // Round up to the nearest page: size_t page_size = get_page_size(); assert((page_size & (page_size - 1)) == 0); - size_t padding = page_size - ((code_size + data_size) & (page_size - 1)); - size_t total_size = code_size + data_size + padding; + size_t padding = page_size - ((code_size + data_size + trampoline_state.size) & (page_size - 1)); + size_t total_size = code_size + data_size + trampoline_state.size + padding; unsigned char *memory = jit_alloc(total_size); if (memory == NULL) { return -1; @@ -430,6 +504,7 @@ _PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction trace[], siz // Loop again to emit the code: unsigned char *code = memory; unsigned char *data = memory + code_size; + trampoline_state.mem = memory + code_size + data_size; // Compile the trampoline, which handles converting between the native // calling convention and the calling convention used by jitted code // (which may be different for efficiency reasons). On platforms where diff --git a/Tools/jit/_stencils.py b/Tools/jit/_stencils.py index 1c6a9edb39840d8..33e8c0e95cfc4cf 100644 --- a/Tools/jit/_stencils.py +++ b/Tools/jit/_stencils.py @@ -2,11 +2,15 @@ import dataclasses import enum -import sys import typing import _schema +# Number of 32-bit words needed to store the bit mask of external symbols +SYMBOL_MASK_SIZE: int = 4 + +known_symbols: dict[str | None, int] = {} + @enum.unique class HoleValue(enum.Enum): @@ -157,7 +161,7 @@ def as_c(self, where: str) -> str: if value: value += " + " value += f"(uintptr_t)&{self.symbol}" - if _signed(self.addend): + if _signed(self.addend) or not value: if value: value += " + " value += f"{_signed(self.addend):#x}" @@ -175,7 +179,6 @@ class Stencil: body: bytearray = dataclasses.field(default_factory=bytearray, init=False) holes: list[Hole] = dataclasses.field(default_factory=list, init=False) disassembly: list[str] = dataclasses.field(default_factory=list, init=False) - trampolines: dict[str, int] = dataclasses.field(default_factory=dict, init=False) def pad(self, alignment: int) -> None: """Pad the stencil to the given alignment.""" @@ -184,39 +187,6 @@ def pad(self, alignment: int) -> None: self.disassembly.append(f"{offset:x}: {' '.join(['00'] * padding)}") self.body.extend([0] * padding) - def emit_aarch64_trampoline(self, hole: Hole, alignment: int) -> Hole: - """Even with the large code model, AArch64 Linux insists on 28-bit jumps.""" - assert hole.symbol is not None - reuse_trampoline = hole.symbol in self.trampolines - if reuse_trampoline: - # Re-use the base address of the previously created trampoline - base = self.trampolines[hole.symbol] - else: - self.pad(alignment) - base = len(self.body) - new_hole = hole.replace(addend=base, symbol=None, value=HoleValue.DATA) - - if reuse_trampoline: - return new_hole - - self.disassembly += [ - f"{base + 4 * 0:x}: 58000048 ldr x8, 8", - f"{base + 4 * 1:x}: d61f0100 br x8", - f"{base + 4 * 2:x}: 00000000", - f"{base + 4 * 2:016x}: R_AARCH64_ABS64 {hole.symbol}", - f"{base + 4 * 3:x}: 00000000", - ] - for code in [ - 0x58000048.to_bytes(4, sys.byteorder), - 0xD61F0100.to_bytes(4, sys.byteorder), - 0x00000000.to_bytes(4, sys.byteorder), - 0x00000000.to_bytes(4, sys.byteorder), - ]: - self.body.extend(code) - self.holes.append(hole.replace(offset=base + 8, kind="R_AARCH64_ABS64")) - self.trampolines[hole.symbol] = base - return new_hole - def remove_jump(self, *, alignment: int = 1) -> None: """Remove a zero-length continuation jump, if it exists.""" hole = max(self.holes, key=lambda hole: hole.offset) @@ -282,6 +252,7 @@ class StencilGroup: default_factory=dict, init=False ) _got: dict[str, int] = dataclasses.field(default_factory=dict, init=False) + trampolines: set[int] = dataclasses.field(default_factory=set, init=False) def process_relocations(self, *, alignment: int = 1) -> None: """Fix up all GOT and internal relocations for this stencil group.""" @@ -291,9 +262,15 @@ def process_relocations(self, *, alignment: int = 1) -> None: in {"R_AARCH64_CALL26", "R_AARCH64_JUMP26", "ARM64_RELOC_BRANCH26"} and hole.value is HoleValue.ZERO ): - new_hole = self.data.emit_aarch64_trampoline(hole, alignment) - self.code.holes.remove(hole) - self.code.holes.append(new_hole) + hole.func = "patch_aarch64_trampoline" + if hole.symbol in known_symbols: + ordinal = known_symbols[hole.symbol] + else: + ordinal = len(known_symbols) + known_symbols[hole.symbol] = ordinal + self.trampolines.add(ordinal) + hole.addend = ordinal + hole.symbol = None self.code.remove_jump(alignment=alignment) self.code.pad(alignment) self.data.pad(8) @@ -348,9 +325,21 @@ def _emit_global_offset_table(self) -> None: ) self.data.body.extend([0] * 8) + def _get_trampoline_mask(self) -> str: + bitmask: int = 0 + trampoline_mask: list[str] = [] + for ordinal in self.trampolines: + bitmask |= 1 << ordinal + if bitmask: + trampoline_mask = [ + f"0x{(bitmask >> i*32) & ((1 << 32) - 1):x}" + for i in range(0, SYMBOL_MASK_SIZE) + ] + return ", ".join(trampoline_mask) + def as_c(self, opname: str) -> str: """Dump this hole as a StencilGroup initializer.""" - return f"{{emit_{opname}, {len(self.code.body)}, {len(self.data.body)}}}" + return f"{{emit_{opname}, {len(self.code.body)}, {len(self.data.body)}, {{{self._get_trampoline_mask()}}}}}" def symbol_to_value(symbol: str) -> tuple[HoleValue, str | None]: diff --git a/Tools/jit/_writer.py b/Tools/jit/_writer.py index 9d11094f85c7fff..d39a3d19e0ba895 100644 --- a/Tools/jit/_writer.py +++ b/Tools/jit/_writer.py @@ -7,12 +7,15 @@ def _dump_footer(groups: dict[str, _stencils.StencilGroup]) -> typing.Iterator[str]: + yield f"typedef uint32_t SymbolMask[{_stencils.SYMBOL_MASK_SIZE}];" + yield "" yield "typedef struct {" yield " void (*emit)(" yield " unsigned char *code, unsigned char *data, _PyExecutorObject *executor," yield " const _PyUOpInstruction *instruction, uintptr_t instruction_starts[]);" yield " size_t code_size;" yield " size_t data_size;" + yield " SymbolMask trampoline_mask;" yield "} StencilGroup;" yield "" yield f"static const StencilGroup trampoline = {groups['trampoline'].as_c('trampoline')};" @@ -23,6 +26,11 @@ def _dump_footer(groups: dict[str, _stencils.StencilGroup]) -> typing.Iterator[s continue yield f" [{opname}] = {group.as_c(opname)}," yield "};" + yield "" + yield "static const void * const symbols_map[] = {" + for symbol, ordinal in _stencils.known_symbols.items(): + yield f" [{ordinal}] = &{symbol}," + yield "};" def _dump_stencil(opname: str, group: _stencils.StencilGroup) -> typing.Iterator[str]: