From b26525949f28a716844f7b5d6c2a1b72eaca5062 Mon Sep 17 00:00:00 2001 From: Heeseon Cheon Date: Tue, 5 Sep 2023 10:51:37 +0900 Subject: [PATCH] feat: add packet.py typing Co-authored-by: sean-k1 --- pymysqlreplication/packet.py | 255 ++++++++++++++++++----------------- 1 file changed, 133 insertions(+), 122 deletions(-) diff --git a/pymysqlreplication/packet.py b/pymysqlreplication/packet.py index 665caebe..ecc38607 100644 --- a/pymysqlreplication/packet.py +++ b/pymysqlreplication/packet.py @@ -4,6 +4,9 @@ from pymysqlreplication import constants, event, row_event +from typing import List, Tuple, Dict, Optional, Union, FrozenSet, Type +from pymysql.connections import MysqlPacket, Connection + # Constants from PyMYSQL source code NULL_COLUMN = 251 UNSIGNED_CHAR_COLUMN = 251 @@ -15,7 +18,6 @@ UNSIGNED_INT24_LENGTH = 3 UNSIGNED_INT64_LENGTH = 8 - JSONB_TYPE_SMALL_OBJECT = 0x0 JSONB_TYPE_LARGE_OBJECT = 0x1 JSONB_TYPE_SMALL_ARRAY = 0x2 @@ -36,24 +38,10 @@ JSONB_LITERAL_FALSE = 0x2 -def read_offset_or_inline(packet, large): - t = packet.read_uint8() - - if t in (JSONB_TYPE_LITERAL, - JSONB_TYPE_INT16, JSONB_TYPE_UINT16): - return (t, None, packet.read_binary_json_type_inlined(t, large)) - if large and t in (JSONB_TYPE_INT32, JSONB_TYPE_UINT32): - return (t, None, packet.read_binary_json_type_inlined(t, large)) - - if large: - return (t, packet.read_uint32(), None) - return (t, packet.read_uint16(), None) - - class BinLogPacketWrapper(object): """ - Bin Log Packet Wrapper. It uses an existing packet object, and wraps - around it, exposing useful variables while still providing access + Bin Log Packet Wrapper uses an existing packet object and wraps around it, + exposing useful variables while still providing access to the original packet objects variables and methods. """ @@ -83,7 +71,7 @@ class BinLogPacketWrapper(object): constants.DELETE_ROWS_EVENT_V2: row_event.DeleteRowsEvent, constants.TABLE_MAP_EVENT: row_event.TableMapEvent, - #5.6 GTID enabled replication events + # 5.6 GTID enabled replication events constants.ANONYMOUS_GTID_LOG_EVENT: event.NotImplementedEvent, # MariaDB GTID constants.MARIADB_ANNOTATE_ROWS_EVENT: event.MariadbAnnotateRowsEvent, @@ -93,26 +81,28 @@ class BinLogPacketWrapper(object): constants.MARIADB_START_ENCRYPTION_EVENT: event.MariadbStartEncryptionEvent } - def __init__(self, from_packet, table_map, - ctl_connection, - mysql_version, - use_checksum, - allowed_events, - only_tables, - ignored_tables, - only_schemas, - ignored_schemas, - freeze_schema, - fail_on_table_metadata_unavailable, - ignore_decode_errors, - verify_checksum,): + def __init__(self, + from_packet: MysqlPacket, + table_map: dict, + ctl_connection: Connection, + mysql_version: Tuple[int, int, int], + use_checksum: bool, + allowed_events: FrozenSet[Type[event.BinLogEvent]], + only_tables: Optional[List[str]], + ignored_tables: Optional[List[str]], + only_schemas: Optional[List[str]], + ignored_schemas: Optional[List[str]], + freeze_schema: bool, + fail_on_table_metadata_unavailable: bool, + ignore_decode_errors: bool, + verify_checksum: bool) -> None: # -1 because we ignore the ok byte self.read_bytes = 0 # Used when we want to override a value in the data buffer self.__data_buffer = b'' - self.packet = from_packet - self.charset = ctl_connection.charset + self.packet: MysqlPacket = from_packet + self.charset: str = ctl_connection.charset # OK value # timestamp @@ -123,13 +113,13 @@ def __init__(self, from_packet, table_map, unpack = struct.unpack(' bytes: size = int(size) self.read_bytes += size if len(self.__data_buffer) > 0: @@ -169,14 +163,15 @@ def read(self, size): return data + self.packet.read(size - len(data)) return self.packet.read(size) - def unread(self, data): - '''Push again data in data buffer. It's use when you want - to extract a bit from a value a let the rest of the code normally - read the datas''' + def unread(self, data: bytes) -> None: + """ + Push again data in data buffer. + Use to extract a bit from a value and ensure that the rest of the code reads data normally + """ self.read_bytes -= len(data) self.__data_buffer += data - def advance(self, size): + def advance(self, size: int) -> None: size = int(size) self.read_bytes += size buffer_len = len(self.__data_buffer) @@ -187,13 +182,11 @@ def advance(self, size): else: self.packet.advance(size) - def read_length_coded_binary(self): - """Read a 'Length Coded Binary' number from the data buffer. - + def read_length_coded_binary(self) -> Optional[int]: + """ + Read a 'Length Coded Binary' number from the data buffer. Length coded numbers can be anywhere from 1 to 9 bytes depending - on the value of the first byte. - - From PyMYSQL source code + on the value of the first byte. (From PyMYSQL source code) """ c = struct.unpack("!B", self.read(1))[0] if c == NULL_COLUMN: @@ -207,14 +200,12 @@ def read_length_coded_binary(self): elif c == UNSIGNED_INT64_COLUMN: return self.unpack_int64(self.read(UNSIGNED_INT64_LENGTH)) - def read_length_coded_string(self): - """Read a 'Length Coded String' from the data buffer. - - A 'Length Coded String' consists first of a length coded - (unsigned, positive) integer represented in 1-9 bytes followed by - that many bytes of binary data. (For example "cat" would be "3cat".) - - From PyMYSQL source code + def read_length_coded_string(self) -> Optional[str]: + """ + Read a 'Length Coded String' from the data buffer. + A 'Length Coded String' consists first of a length coded (unsigned, positive) integer + represented in 1-9 bytes followed by that many bytes of binary data. + (For example, "cat" would be "3cat". - From PyMYSQL source code) """ length = self.read_length_coded_binary() if length is None: @@ -228,8 +219,10 @@ def __getattr__(self, key): raise AttributeError("%s instance has no attribute '%s'" % (self.__class__, key)) - def read_int_be_by_size(self, size): - '''Read a big endian integer values based on byte number''' + def read_int_be_by_size(self, size: int) -> int: + """ + Read a big endian integer values based on byte number + """ if size == 1: return struct.unpack('>b', self.read(size))[0] elif size == 2: @@ -243,8 +236,10 @@ def read_int_be_by_size(self, size): elif size == 8: return struct.unpack('>l', self.read(size))[0] - def read_uint_by_size(self, size): - '''Read a little endian integer values based on byte number''' + def read_uint_by_size(self, size: int) -> int: + """ + Read a little endian integer values based on byte number + """ if size == 1: return self.read_uint8() elif size == 2: @@ -262,19 +257,18 @@ def read_uint_by_size(self, size): elif size == 8: return self.read_uint64() - def read_length_coded_pascal_string(self, size): - """Read a string with length coded using pascal style. + def read_length_coded_pascal_string(self, size: int) -> bytes: + """ + Read a string with length coded using pascal style. The string start by the size of the string """ length = self.read_uint_by_size(size) return self.read(length) - def read_variable_length_string(self): - """Read a variable length string where the first 1-5 bytes stores the - length of the string. - - For each byte, the first bit being high indicates another byte must be - read. + def read_variable_length_string(self) -> bytes: + """ + Read a variable length string where the first 1-5 bytes stores the length of the string. + For each byte, the first bit being high indicates another byte must be read. """ byte = 0x80 length = 0 @@ -285,82 +279,82 @@ def read_variable_length_string(self): bits_read = bits_read + 7 return self.read(length) - def read_int24(self): + def read_int24(self) -> int: a, b, c = struct.unpack("BBB", self.read(3)) res = a | (b << 8) | (c << 16) if res >= 0x800000: res -= 0x1000000 return res - def read_int24_be(self): + def read_int24_be(self) -> int: a, b, c = struct.unpack('BBB', self.read(3)) res = (a << 16) | (b << 8) | c if res >= 0x800000: res -= 0x1000000 return res - def read_uint8(self): + def read_uint8(self) -> int: return struct.unpack(' int: return struct.unpack(' int: return struct.unpack(' int: a, b, c = struct.unpack(" int: return struct.unpack(' int: return struct.unpack(' int: a, b = struct.unpack(" int: a, b = struct.unpack(">IB", self.read(5)) return b + (a << 8) - def read_uint48(self): + def read_uint48(self) -> int: a, b, c = struct.unpack(" int: a, b, c = struct.unpack(" int: return struct.unpack(' int: return struct.unpack(' int: return struct.unpack(' Optional[Union[int, Tuple[str, int]]]: try: - return struct.unpack('B', n[0])[0] \ - + (struct.unpack('B', n[1])[0] << 8) \ - + (struct.unpack('B', n[2])[0] << 16) + return struct.unpack('B', n[0:1])[0] \ + + (struct.unpack('B', n[1:2])[0] << 8) \ + + (struct.unpack('B', n[2:3])[0] << 16) except TypeError: return n[0] + (n[1] << 8) + (n[2] << 16) - def unpack_int32(self, n): + def unpack_int32(self, n: bytes) -> Optional[Union[int, Tuple[str, int]]]: try: - return struct.unpack('B', n[0])[0] \ - + (struct.unpack('B', n[1])[0] << 8) \ - + (struct.unpack('B', n[2])[0] << 16) \ - + (struct.unpack('B', n[3])[0] << 24) + return struct.unpack('B', n[0:1])[0] \ + + (struct.unpack('B', n[1:2])[0] << 8) \ + + (struct.unpack('B', n[2:3])[0] << 16) \ + + (struct.unpack('B', n[3:4])[0] << 24) except TypeError: return n[0] + (n[1] << 8) + (n[2] << 16) + (n[3] << 24) - def read_binary_json(self, size): + def read_binary_json(self, size: int) -> Optional[str]: length = self.read_uint_by_size(size) if length == 0: # handle NULL value @@ -371,7 +365,10 @@ def read_binary_json(self, size): return self.read_binary_json_type(t, length) - def read_binary_json_type(self, t, length): + def read_binary_json_type(self, t: int, length: int) \ + -> Optional[Union[ + Dict[bytes, Union[bool, str, None]], + List[int], bool, int, bytes]]: large = (t in (JSONB_TYPE_LARGE_OBJECT, JSONB_TYPE_LARGE_ARRAY)) if t in (JSONB_TYPE_SMALL_OBJECT, JSONB_TYPE_LARGE_OBJECT): return self.read_binary_json_object(length - 1, large) @@ -404,7 +401,7 @@ def read_binary_json_type(self, t, length): raise ValueError('Json type %d is not handled' % t) - def read_binary_json_type_inlined(self, t, large): + def read_binary_json_type_inlined(self, t: int, large: bool) -> Optional[Union[bool, int]]: if t == JSONB_TYPE_LITERAL: value = self.read_uint32() if large else self.read_uint16() if value == JSONB_LITERAL_NULL: @@ -424,7 +421,8 @@ def read_binary_json_type_inlined(self, t, large): raise ValueError('Json type %d is not handled' % t) - def read_binary_json_object(self, length, large): + def read_binary_json_object(self, length: int, large: bool) \ + -> Dict[bytes, Union[bool, str, None]]: if large: elements = self.read_uint32() size = self.read_uint32() @@ -438,13 +436,13 @@ def read_binary_json_object(self, length, large): if large: key_offset_lengths = [( self.read_uint32(), # offset (we don't actually need that) - self.read_uint16() # size of the key - ) for _ in range(elements)] + self.read_uint16() # size of the key + ) for _ in range(elements)] else: key_offset_lengths = [( self.read_uint16(), # offset (we don't actually need that) - self.read_uint16() # size of key - ) for _ in range(elements)] + self.read_uint16() # size of key + ) for _ in range(elements)] value_type_inlined_lengths = [read_offset_or_inline(self, large) for _ in range(elements)] @@ -462,7 +460,7 @@ def read_binary_json_object(self, length, large): return out - def read_binary_json_array(self, length, large): + def read_binary_json_array(self, length: int, large: bool) -> List[int]: if large: elements = self.read_uint32() size = self.read_uint32() @@ -477,20 +475,18 @@ def read_binary_json_array(self, length, large): read_offset_or_inline(self, large) for _ in range(elements)] - def _read(x): + def _read(x: Tuple[int, Optional[bytes], Optional[Union[bool, int]]]) -> int: if x[1] is None: return x[2] return self.read_binary_json_type(x[0], length) return [_read(x) for x in values_type_offset_inline] - def read_string(self): - """Read a 'Length Coded String' from the data buffer. - + def read_string(self) -> bytes: + """ + Read a 'Length Coded String' from the data buffer. Read __data_buffer until NULL character (0 = \0 = \x00) - - Returns: - Binary string parsed from __data_buffer + :return string: Binary string parsed from __data_buffer """ string = b'' while True: @@ -500,3 +496,18 @@ def read_string(self): string += char return string + + +def read_offset_or_inline(packet: Union[MysqlPacket, BinLogPacketWrapper], large: bool) \ + -> Tuple[int, Optional[bytes], Optional[Union[bool, int]]]: + t = packet.read_uint8() + + if t in (JSONB_TYPE_LITERAL, + JSONB_TYPE_INT16, JSONB_TYPE_UINT16): + return t, None, packet.read_binary_json_type_inlined(t, large) + if large and t in (JSONB_TYPE_INT32, JSONB_TYPE_UINT32): + return t, None, packet.read_binary_json_type_inlined(t, large) + + if large: + return t, packet.read_uint32(), None + return t, packet.read_uint16(), None