Skip to content

Commit

Permalink
pythongh-119726: generate and patch AArch64 trampolines
Browse files Browse the repository at this point in the history
AArch64 trampolines are now generated at runtime at
the end of every trace.
  • Loading branch information
diegorusso committed Sep 9, 2024
1 parent 65fcaa3 commit 6d2ea89
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 42 deletions.
79 changes: 77 additions & 2 deletions Python/jit.c
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "Python.h"

#include "pycore_abstract.h"
#include "pycore_bitutils.h"
#include "pycore_call.h"
#include "pycore_ceval.h"
#include "pycore_critical_section.h"
Expand Down Expand Up @@ -390,8 +391,70 @@ patch_x86_64_32rx(unsigned char *location, uint64_t value)
patch_32r(location, value);
}

void patch_aarch64_trampoline(unsigned char *location, int ordinal);

#include "jit_stencils.h"

// Bookkeeping for the AArch64 trampolines emitted for the trace that is
// currently being compiled.
typedef struct {
// Start of the trampoline area. _PyJIT_Compile points it just past the code
// and data regions of the freshly allocated JIT memory.
void *mem;
// Union of the trampoline_mask of every stencil group used by the trace:
// one bit per symbol ordinal, stored as an array of 32-bit words.
SymbolMask mask;
// Size in bytes of the trampoline area: TRAMPOLINE_SIZE for every bit set
// in mask.
size_t size;
} TrampolineState;

// TODO: Avoid keeping the trampoline state in a global variable.
TrampolineState trampoline_state;

#if defined(__aarch64__) || defined(_M_ARM64)
#define TRAMPOLINE_SIZE 16
#else
#define TRAMPOLINE_SIZE 0
#endif

// Generate and patch AArch64 trampolines. The symbols to jump to are stored
// in the jit_stencils.h in the symbols_map.
void
patch_aarch64_trampoline(unsigned char *location, int ordinal)
{
// Masking is done modulo 32 as the mask is stored as an array of uint32_t
const uint32_t symbol_mask = 1 << (ordinal % 32);
const uint32_t trampoline_mask = trampoline_state.mask[ordinal / 32];
assert(symbol_mask & trampoline_mask);

// Count the number of set bits in the trampoline mask lower than ordinal,
// this gives the index into the array of trampolines.
int index = _Py_popcount32(trampoline_mask & (symbol_mask - 1));
for (int i = 0; i < ordinal / 32; i++) {
index += _Py_popcount32(trampoline_state.mask[i]);
}

uint32_t *p = trampoline_state.mem + index * TRAMPOLINE_SIZE;
assert((size_t)index * TRAMPOLINE_SIZE < trampoline_state.size);

uintptr_t value = (uintptr_t)symbols_map[ordinal];

/* Generate the trampoline
0: 58000048 ldr x8, 8
4: d61f0100 br x8
8: 00000000 // The next two words contain the 64-bit address to jump to.
c: 00000000
*/
p[0] = 0x58000048;
p[1] = 0xD61F0100;
p[2] = value & 0xffffffff;
p[3] = value >> 32;

patch_aarch64_26r(location, (uintptr_t)p);
}

// Merge the trampoline requirements of one StencilGroup into an accumulated
// mask: every one of the `size` words of `src` is OR-ed into the matching
// word of `dest`. `src` is left untouched.
static void
combine_symbol_mask(const SymbolMask src, SymbolMask dest, size_t size)
{
    size_t word = 0;
    while (word < size) {
        dest[word] = dest[word] | src[word];
        word++;
    }
}

// Compiles executor in-place. Don't forget to call _PyJIT_Free later!
int
_PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction trace[], size_t length)
Expand All @@ -401,6 +464,7 @@ _PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction trace[], siz
uintptr_t instruction_starts[UOP_MAX_TRACE_LENGTH];
size_t code_size = 0;
size_t data_size = 0;
trampoline_state = (TrampolineState){};
group = &trampoline;
code_size += group->code_size;
data_size += group->data_size;
Expand All @@ -410,15 +474,25 @@ _PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction trace[], siz
instruction_starts[i] = code_size;
code_size += group->code_size;
data_size += group->data_size;
combine_symbol_mask(group->trampoline_mask,
trampoline_state.mask,
Py_ARRAY_LENGTH(trampoline_state.mask));
}
group = &stencil_groups[_FATAL_ERROR];
code_size += group->code_size;
data_size += group->data_size;
combine_symbol_mask(group->trampoline_mask,
trampoline_state.mask,
Py_ARRAY_LENGTH(trampoline_state.mask));
// Calculate the size of the trampolines required by the whole trace
for (size_t i = 0; i < Py_ARRAY_LENGTH(trampoline_state.mask); i++) {
trampoline_state.size += _Py_popcount32(trampoline_state.mask[i]) * TRAMPOLINE_SIZE;
}
// Round up to the nearest page:
size_t page_size = get_page_size();
assert((page_size & (page_size - 1)) == 0);
size_t padding = page_size - ((code_size + data_size) & (page_size - 1));
size_t total_size = code_size + data_size + padding;
size_t padding = page_size - ((code_size + data_size + trampoline_state.size) & (page_size - 1));
size_t total_size = code_size + data_size + trampoline_state.size + padding;
unsigned char *memory = jit_alloc(total_size);
if (memory == NULL) {
return -1;
Expand All @@ -430,6 +504,7 @@ _PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction trace[], siz
// Loop again to emit the code:
unsigned char *code = memory;
unsigned char *data = memory + code_size;
trampoline_state.mem = memory + code_size + data_size;
// Compile the trampoline, which handles converting between the native
// calling convention and the calling convention used by jitted code
// (which may be different for efficiency reasons). On platforms where
Expand Down
69 changes: 29 additions & 40 deletions Tools/jit/_stencils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@

import dataclasses
import enum
import sys
import typing

import _schema

# Number of 32-bit words needed to store the bit mask of external symbols
SYMBOL_MASK_SIZE: int = 4

known_symbols: dict[str | None, int] = {}


@enum.unique
class HoleValue(enum.Enum):
Expand Down Expand Up @@ -157,7 +161,7 @@ def as_c(self, where: str) -> str:
if value:
value += " + "
value += f"(uintptr_t)&{self.symbol}"
if _signed(self.addend):
if _signed(self.addend) or not value:
if value:
value += " + "
value += f"{_signed(self.addend):#x}"
Expand All @@ -175,7 +179,6 @@ class Stencil:
body: bytearray = dataclasses.field(default_factory=bytearray, init=False)
holes: list[Hole] = dataclasses.field(default_factory=list, init=False)
disassembly: list[str] = dataclasses.field(default_factory=list, init=False)
trampolines: dict[str, int] = dataclasses.field(default_factory=dict, init=False)

def pad(self, alignment: int) -> None:
"""Pad the stencil to the given alignment."""
Expand All @@ -184,39 +187,6 @@ def pad(self, alignment: int) -> None:
self.disassembly.append(f"{offset:x}: {' '.join(['00'] * padding)}")
self.body.extend([0] * padding)

def emit_aarch64_trampoline(self, hole: Hole, alignment: int) -> Hole:
    """Route a 26-bit AArch64 branch through a trampoline in this stencil.

    Even with the large code model, AArch64 Linux insists on 28-bit jumps,
    so the branch described by *hole* is redirected to a small
    ``ldr x8, 8; br x8`` sequence followed by the 64-bit target address.
    One trampoline is emitted per symbol and shared by later holes that
    target the same symbol.

    Returns a copy of *hole* that targets the trampoline (DATA + base
    offset) instead of the symbol itself.
    """
    assert hole.symbol is not None
    symbol = hole.symbol

    # A trampoline for this symbol already exists: just point at it.
    if symbol in self.trampolines:
        return hole.replace(
            addend=self.trampolines[symbol], symbol=None, value=HoleValue.DATA
        )

    # Otherwise append a new, aligned trampoline at the end of the body.
    self.pad(alignment)
    base = len(self.body)
    redirected = hole.replace(addend=base, symbol=None, value=HoleValue.DATA)

    self.disassembly.extend(
        [
            f"{base:x}: 58000048 ldr x8, 8",
            f"{base + 4:x}: d61f0100 br x8",
            f"{base + 8:x}: 00000000",
            f"{base + 8:016x}: R_AARCH64_ABS64 {symbol}",
            f"{base + 12:x}: 00000000",
        ]
    )
    # Two instructions followed by a zeroed 64-bit slot for the address.
    for word in (0x58000048, 0xD61F0100, 0x00000000, 0x00000000):
        self.body.extend(word.to_bytes(4, sys.byteorder))

    # The address slot is filled in later through an absolute relocation.
    self.holes.append(hole.replace(offset=base + 8, kind="R_AARCH64_ABS64"))
    self.trampolines[symbol] = base
    return redirected

def remove_jump(self, *, alignment: int = 1) -> None:
"""Remove a zero-length continuation jump, if it exists."""
hole = max(self.holes, key=lambda hole: hole.offset)
Expand Down Expand Up @@ -282,6 +252,7 @@ class StencilGroup:
default_factory=dict, init=False
)
_got: dict[str, int] = dataclasses.field(default_factory=dict, init=False)
trampolines: set[int] = dataclasses.field(default_factory=set, init=False)

def process_relocations(self, *, alignment: int = 1) -> None:
"""Fix up all GOT and internal relocations for this stencil group."""
Expand All @@ -291,9 +262,15 @@ def process_relocations(self, *, alignment: int = 1) -> None:
in {"R_AARCH64_CALL26", "R_AARCH64_JUMP26", "ARM64_RELOC_BRANCH26"}
and hole.value is HoleValue.ZERO
):
new_hole = self.data.emit_aarch64_trampoline(hole, alignment)
self.code.holes.remove(hole)
self.code.holes.append(new_hole)
hole.func = "patch_aarch64_trampoline"
if hole.symbol in known_symbols:
ordinal = known_symbols[hole.symbol]
else:
ordinal = len(known_symbols)
known_symbols[hole.symbol] = ordinal
self.trampolines.add(ordinal)
hole.addend = ordinal
hole.symbol = None
self.code.remove_jump(alignment=alignment)
self.code.pad(alignment)
self.data.pad(8)
Expand Down Expand Up @@ -348,9 +325,21 @@ def _emit_global_offset_table(self) -> None:
)
self.data.body.extend([0] * 8)

def _get_trampoline_mask(self) -> str:
    """Return the body of the C initializer for this group's trampoline mask.

    Every ordinal in ``self.trampolines`` sets one bit of a bit mask that is
    emitted as ``SYMBOL_MASK_SIZE`` comma-separated 32-bit hexadecimal
    words, least significant word first. When the group needs no
    trampolines an empty string is returned, so the caller emits an empty
    initializer.

    Raises:
        ValueError: if an ordinal does not fit in the mask. Dropping the bit
            silently would produce a header that lacks a trampoline the
            stencil needs.
    """
    mask_bits = 32 * SYMBOL_MASK_SIZE
    bitmask = 0
    for ordinal in self.trampolines:
        if not 0 <= ordinal < mask_bits:
            raise ValueError(
                f"trampoline ordinal {ordinal} does not fit in a "
                f"{mask_bits}-bit symbol mask; increase SYMBOL_MASK_SIZE"
            )
        bitmask |= 1 << ordinal
    if not bitmask:
        return ""
    return ", ".join(
        f"0x{(bitmask >> (32 * word)) & 0xFFFFFFFF:x}"
        for word in range(SYMBOL_MASK_SIZE)
    )

def as_c(self, opname: str) -> str:
"""Dump this hole as a StencilGroup initializer."""
return f"{{emit_{opname}, {len(self.code.body)}, {len(self.data.body)}}}"
return f"{{emit_{opname}, {len(self.code.body)}, {len(self.data.body)}, {{{self._get_trampoline_mask()}}}}}"


def symbol_to_value(symbol: str) -> tuple[HoleValue, str | None]:
Expand Down
8 changes: 8 additions & 0 deletions Tools/jit/_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@


def _dump_footer(groups: dict[str, _stencils.StencilGroup]) -> typing.Iterator[str]:
yield f"typedef uint32_t SymbolMask[{_stencils.SYMBOL_MASK_SIZE}];"
yield ""
yield "typedef struct {"
yield " void (*emit)("
yield " unsigned char *code, unsigned char *data, _PyExecutorObject *executor,"
yield " const _PyUOpInstruction *instruction, uintptr_t instruction_starts[]);"
yield " size_t code_size;"
yield " size_t data_size;"
yield " SymbolMask trampoline_mask;"
yield "} StencilGroup;"
yield ""
yield f"static const StencilGroup trampoline = {groups['trampoline'].as_c('trampoline')};"
Expand All @@ -23,6 +26,11 @@ def _dump_footer(groups: dict[str, _stencils.StencilGroup]) -> typing.Iterator[s
continue
yield f" [{opname}] = {group.as_c(opname)},"
yield "};"
yield ""
yield "static const void * const symbols_map[] = {"
for symbol, ordinal in _stencils.known_symbols.items():
yield f" [{ordinal}] = &{symbol},"
yield "};"


def _dump_stencil(opname: str, group: _stencils.StencilGroup) -> typing.Iterator[str]:
Expand Down

0 comments on commit 6d2ea89

Please sign in to comment.