Skip to content

Commit

Permalink
Improve type hints for many __getitem__ impls
Browse files Browse the repository at this point in the history
  • Loading branch information
negasora committed Apr 22, 2024
1 parent 3256028 commit 72cb4b0
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 11 deletions.
8 changes: 7 additions & 1 deletion python/binaryview.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import os
import uuid
from typing import Callable, Generator, Optional, Union, Tuple, List, Mapping, Any, \
Iterator, Iterable, KeysView, ItemsView, ValuesView, Dict
Iterator, Iterable, KeysView, ItemsView, ValuesView, Dict, overload
from dataclasses import dataclass
from enum import IntFlag

Expand Down Expand Up @@ -1985,6 +1985,12 @@ def __next__(self):
self._n += 1
return _function.Function(self._view, func)

@overload
def __getitem__(self, i: int) -> '_function.Function': ...

@overload
def __getitem__(self, i: slice) -> List['_function.Function']: ...

def __getitem__(self, i: Union[int, slice]) -> Union['_function.Function', List['_function.Function']]:
if isinstance(i, int):
if i < 0:
Expand Down
43 changes: 39 additions & 4 deletions python/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import ctypes
import inspect
from typing import Generator, Optional, List, Tuple, Union, Mapping, Any, Dict
from typing import Generator, Optional, List, Tuple, Union, Mapping, Any, Dict, overload
from dataclasses import dataclass

# Binary Ninja components
Expand Down Expand Up @@ -211,6 +211,12 @@ def __next__(self) -> 'basicblock.BasicBlock':
self._n += 1
return self._function._instantiate_block(block)

@overload
def __getitem__(self, i: int) -> 'basicblock.BasicBlock': ...

@overload
def __getitem__(self, i: slice) -> List['basicblock.BasicBlock']: ...

def __getitem__(self, i: Union[int, slice]) -> Union['basicblock.BasicBlock', List['basicblock.BasicBlock']]:
if isinstance(i, int):
if i < 0:
Expand All @@ -237,6 +243,12 @@ class LowLevelILBasicBlockList(BasicBlockList):
def __repr__(self):
return f"<LowLevelILBasicBlockList {len(self)} BasicBlocks: {list(self)}>"

@overload
def __getitem__(self, i: int) -> 'lowlevelil.LowLevelILBasicBlock': ...

@overload
def __getitem__(self, i: slice) -> List['lowlevelil.LowLevelILBasicBlock']: ...

def __getitem__(
self, i: Union[int, slice]
) -> Union['lowlevelil.LowLevelILBasicBlock', List['lowlevelil.LowLevelILBasicBlock']]:
Expand All @@ -250,6 +262,12 @@ class MediumLevelILBasicBlockList(BasicBlockList):
def __repr__(self):
return f"<MediumLevelILBasicBlockList {len(self)} BasicBlocks: {list(self)}>"

@overload
def __getitem__(self, i: int) -> 'mediumlevelil.MediumLevelILBasicBlock': ...

@overload
def __getitem__(self, i: slice) -> List['mediumlevelil.MediumLevelILBasicBlock']: ...

def __getitem__(
self, i: Union[int, slice]
) -> Union['mediumlevelil.MediumLevelILBasicBlock', List['mediumlevelil.MediumLevelILBasicBlock']]:
Expand All @@ -263,6 +281,12 @@ class HighLevelILBasicBlockList(BasicBlockList):
def __repr__(self):
return f"<HighLevelILBasicBlockList {len(self)} BasicBlocks: {list(self)}>"

@overload
def __getitem__(self, i: int) -> 'highlevelil.HighLevelILBasicBlock': ...

@overload
def __getitem__(self, i: slice) -> List['highlevelil.HighLevelILBasicBlock']: ...

def __getitem__(
self, i: Union[int, slice]
) -> Union['highlevelil.HighLevelILBasicBlock', List['highlevelil.HighLevelILBasicBlock']]:
Expand Down Expand Up @@ -304,10 +328,15 @@ def __next__(self) -> Tuple['architecture.Architecture', int, 'binaryview.Tag']:
self._n += 1
return arch, address, binaryview.Tag(core_tag)

@overload
def __getitem__(self, i: int) -> Tuple['architecture.Architecture', int, 'binaryview.Tag']: ...

@overload
def __getitem__(self, i: slice) -> List[Tuple['architecture.Architecture', int, 'binaryview.Tag']]: ...

def __getitem__(
self, i: Union[int, slice]
) -> Union[Tuple['architecture.Architecture', int, 'binaryview.Tag'], List[Tuple['architecture.Architecture', int,
'binaryview.Tag']]]:
) -> Union[Tuple['architecture.Architecture', int, 'binaryview.Tag'], List[Tuple['architecture.Architecture', int, 'binaryview.Tag']]]:
if isinstance(i, int):
if i < 0:
i = len(self) + i
Expand Down Expand Up @@ -400,7 +429,13 @@ def __ge__(self, other: 'Function') -> bool:
def __hash__(self):
return hash((self.start, self.arch, self.platform))

def __getitem__(self, i) -> Union['basicblock.BasicBlock', List['basicblock.BasicBlock']]:
@overload
def __getitem__(self, i: int) -> 'basicblock.BasicBlock': ...

@overload
def __getitem__(self, i: slice) -> List['basicblock.BasicBlock']: ...

def __getitem__(self, i: Union[int, slice]) -> Union['basicblock.BasicBlock', List['basicblock.BasicBlock']]:
return self.basic_blocks[i]

def __iter__(self) -> Generator['basicblock.BasicBlock', None, None]:
Expand Down
10 changes: 8 additions & 2 deletions python/highlevelil.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import ctypes
import struct
from typing import Optional, Generator, List, Union, NewType, Tuple, ClassVar, Mapping, Set, Callable, Any, Iterator
from typing import Optional, Generator, List, Union, NewType, Tuple, ClassVar, Mapping, Set, Callable, Any, Iterator, overload
from dataclasses import dataclass
from enum import Enum

Expand Down Expand Up @@ -3089,7 +3089,13 @@ def __iter__(self) -> Generator[HighLevelILInstruction, None, None]:
for idx in range(self.start, self.end):
yield self.il_function[idx]

def __getitem__(self, idx) -> Union[List[HighLevelILInstruction], HighLevelILInstruction]:
@overload
def __getitem__(self, idx: int) -> 'HighLevelILInstruction': ...

@overload
def __getitem__(self, idx: slice) -> List['HighLevelILInstruction']: ...

def __getitem__(self, idx: Union[int, slice]) -> Union[List[HighLevelILInstruction], HighLevelILInstruction]:
size = self.end - self.start
if isinstance(idx, slice):
return [self[index] for index in range(*idx.indices(size))] # type: ignore
Expand Down
10 changes: 8 additions & 2 deletions python/lowlevelil.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import ctypes
import struct
from typing import Generator, List, Optional, Dict, Union, Tuple, NewType, ClassVar, Set, Callable, Any, Iterator
from typing import Generator, List, Optional, Dict, Union, Tuple, NewType, ClassVar, Set, Callable, Any, Iterator, overload
from dataclasses import dataclass

# Binary Ninja components
Expand Down Expand Up @@ -5503,7 +5503,13 @@ def __iter__(self) -> Generator['LowLevelILInstruction', None, None]:
for idx in range(self.start, self.end):
yield self._il_function[idx]

def __getitem__(self, idx):
@overload
def __getitem__(self, idx: int) -> 'LowLevelILInstruction': ...

@overload
def __getitem__(self, idx: slice) -> List['LowLevelILInstruction']: ...

def __getitem__(self, idx: Union[int, slice]) -> Union['LowLevelILInstruction', List['LowLevelILInstruction']]:
size = self.end - self.start
if isinstance(idx, slice):
return [self[index] for index in range(*idx.indices(size))]
Expand Down
10 changes: 8 additions & 2 deletions python/mediumlevelil.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import ctypes
import struct
from typing import (Optional, List, Union, Mapping,
Generator, NewType, Tuple, ClassVar, Dict, Set, Callable, Any, Iterator)
Generator, NewType, Tuple, ClassVar, Dict, Set, Callable, Any, Iterator, overload)
from dataclasses import dataclass
from . import deprecation

Expand Down Expand Up @@ -3957,7 +3957,13 @@ def __iter__(self):
for idx in range(self.start, self.end):
yield self._il_function[idx]

def __getitem__(self, idx) -> Union[List['MediumLevelILInstruction'], 'MediumLevelILInstruction']:
@overload
def __getitem__(self, idx: int) -> 'MediumLevelILInstruction': ...

@overload
def __getitem__(self, idx: slice) -> List['MediumLevelILInstruction']: ...

def __getitem__(self, idx: Union[int, slice]) -> Union[List['MediumLevelILInstruction'], 'MediumLevelILInstruction']:
size = self.end - self.start
if isinstance(idx, slice):
return [self[index] for index in range(*idx.indices(size))] # type: ignore
Expand Down

0 comments on commit 72cb4b0

Please sign in to comment.