Skip to content

Commit

Permalink
Types (#62)
Browse files Browse the repository at this point in the history
* type information added
  • Loading branch information
AndGem authored Dec 14, 2019
1 parent 2583f50 commit b84f364
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 40 deletions.
6 changes: 4 additions & 2 deletions graph/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import utils.timer as timer


def BFS(graph, s):
from graph.graph import Graph
from typing import Set
def BFS(graph: Graph, s: int) -> Set[int]:
seen_nodes = set([s])
unvisited_nodes = deque([s])

Expand Down Expand Up @@ -35,7 +37,7 @@ def computeLCC(graph):
return lcc


def computeLCCGraph(graph):
def computeLCCGraph(graph: Graph) -> Graph:
lcc = computeLCC(graph)
new_nodes = [graph.vertices[id] for id in lcc]
return graphfactory.build_graph_from_vertices_edges(new_nodes, graph.edges)
21 changes: 12 additions & 9 deletions graph/contract_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import utils.timer as timer


from graph.graph import Graph
from graph.graph_types import Edge, SimpleEdge, Vertex
from typing import Dict, List, Optional, Set, Tuple
@timer.timer
def contract(graph):
all_new_edges = find_new_edges(graph)
Expand All @@ -15,11 +18,11 @@ def contract(graph):
return graphfactory.build_graph_from_vertices_edges(nodes, filtered_edges)


def get_nodes(graph, node_ids):
def get_nodes(graph: Graph, node_ids: Set[int]) -> List[Vertex]:
return list(map(lambda node_id: graph.get_node(node_id), node_ids))


def gather_node_ids(edges):
def gather_node_ids(edges: List[SimpleEdge]) -> Set[int]:
print("\t gathering nodes...")
node_ids = set()
for e in edges:
Expand All @@ -28,7 +31,7 @@ def gather_node_ids(edges):
return node_ids


def remove_duplicates(edges):
def remove_duplicates(edges: List[SimpleEdge]) -> List[SimpleEdge]:
print("\t removing duplicate edges...")
added_edges = set()
filtered_edges = []
Expand All @@ -40,7 +43,7 @@ def remove_duplicates(edges):
return filtered_edges


def find_new_edges(graph):
def find_new_edges(graph: Graph) -> List[SimpleEdge]:
edge_by_s_t = edge_mapping(graph)
all_new_edges = []
for node_id in range(len(graph.vertices)):
Expand All @@ -56,7 +59,7 @@ def find_new_edges(graph):
return all_new_edges


def edge_mapping(graph):
def edge_mapping(graph: Graph) -> Dict[Tuple[int, int], Edge]:
edge_by_s_t = {}
for edge in graph.edges:
edge_by_s_t[(edge.s, edge.t)] = edge
Expand All @@ -69,11 +72,11 @@ def edge_mapping(graph):
return edge_by_s_t


def is_important_node(graph, node_id):
def is_important_node(graph: Graph, node_id: int) -> bool:
return len(graph.all_neighbors(node_id)) != 2


def get_edges(nodes, edges_by_s_t):
def get_edges(nodes: List[int], edges_by_s_t: Dict[Tuple[int, int], Edge]) -> List[Edge]:
if nodes:
edges = []
for i in range(len(nodes) - 1):
Expand All @@ -83,7 +86,7 @@ def get_edges(nodes, edges_by_s_t):
return []


def merge_edges(edges):
def merge_edges(edges: List[Edge]) -> Optional[SimpleEdge]:
if edges:
s, t = edges[0].s, edges[-1].t
if s != t:
Expand All @@ -92,7 +95,7 @@ def merge_edges(edges):
return None


def nodes_to_next_important_node(graph, start_node, next_node):
def nodes_to_next_important_node(graph: Graph, start_node: int, next_node: int) -> List[int]:
if start_node == next_node:
print("\t something is wrong here...")
return []
Expand Down
12 changes: 7 additions & 5 deletions graph/graph.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@

from graph.graph_types import Edge, SimpleEdge, Vertex
from typing import List, Union
class Graph(object):

def __init__(self):
def __init__(self) -> None:
self.edges = []
self.vertices = []
self.outneighbors = []
self.inneighbors = []

def add_edge(self, edge):
def add_edge(self, edge: Union[SimpleEdge, Edge]) -> None:
self.edges.append(edge)

if edge.forward:
Expand All @@ -18,12 +20,12 @@ def add_edge(self, edge):
self.outneighbors[edge.t].add(edge.s)
self.inneighbors[edge.s].add(edge.t)

def add_node(self, vertex):
def add_node(self, vertex: Vertex) -> None:
self.vertices.append(vertex)
self.outneighbors.append(set())
self.inneighbors.append(set())

def get_node(self, node_id):
def get_node(self, node_id: int) -> Vertex:
return self.vertices[node_id]

def edge_description(self, edge_id):
Expand All @@ -32,5 +34,5 @@ def edge_description(self, edge_id):
def edge_name(self, edge_id):
return "{}".format(self.edges[edge_id].name)

def all_neighbors(self, node_id):
def all_neighbors(self, node_id: int) -> List[int]:
return list(self.outneighbors[node_id].union(self.inneighbors[node_id]))
18 changes: 9 additions & 9 deletions graph/graph_types.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
class Vertex(object):
__slots__ = ["id", "lat", "lon"]

def __init__(self, id, lat, lon):
def __init__(self, id: int, lat: float, lon: float) -> None:
self.id = id
self.lat, self.lon = lat, lon

@property
def description(self):
def description(self) -> str:
return "{} {} {}".format(self.id, self.lat, self.lon)


class Edge(object):
__slots__ = ["s", "t", "length", "highway", "max_v", "forward", "backward", "name"]

def __init__(self, s, t, length, highway, max_v, f, b, name):
def __init__(self, s: int, t: int, length: float, highway: str, max_v: int, f: bool, b: bool, name: str) -> None:
self.s, self.t = s, t
self.length = length
self.highway = highway
Expand All @@ -22,27 +22,27 @@ def __init__(self, s, t, length, highway, max_v, f, b, name):
self.name = name

@property
def description(self):
def description(self) -> str:
both_directions = "1" if self.forward and self.backward else "0"
return "{} {} {} {} {} {}".format(self.s, self.t, self.length, self.highway, self.max_v, both_directions)


class SimpleEdge(object):
__slots__ = ["s", "t", "length", "name"]

def __init__(self, s, t, length):
def __init__(self, s: int, t: int, length: float) -> None:
self.s, self.t = s, t
self.length = length
self.name = ""

@property
def forward(self):
def forward(self) -> bool:
return True

@property
def backward(self):
def backward(self) -> bool:
return True

@property
def description(self):
return "{} {} {}".format(self.s, self.t, self.length)
def description(self) -> str:
return "{} {} {}".format(self.s, self.t, self.length)
10 changes: 6 additions & 4 deletions osm/osm_types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# -*- coding: utf-8 -*-
from typing import Optional

class OSMWay:
__slots__ = ["osm_id", "nodes", "highway", "area", "max_speed", "direction", "forward", "backward", "name"]

def __init__(self, osm_id):
def __init__(self, osm_id: int) -> None:
self.osm_id = osm_id
self.nodes = []
self.highway = None
Expand All @@ -13,13 +15,13 @@ def __init__(self, osm_id):
self.backward = True
self.name = ""

def add_node(self, osm_id):
def add_node(self, osm_id: int) -> None:
self.nodes.append(osm_id)


class OSMNode(object):
__slots__ = ["lat", "lon", "osm_id"]

def __init__(self, osm_id, lat=None, lon=None):
def __init__(self, osm_id: int, lat: Optional[float] = None, lon: Optional[float] = None) -> None:
self.osm_id = osm_id
self.lat, self.lon = lat, lon
self.lat, self.lon = lat, lon
21 changes: 12 additions & 9 deletions osm/xml_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

from osm.osm_types import OSMWay, OSMNode

from osm.way_parser_helper import WayParserHelper
from typing import Optional, Set
from xml.sax.xmlreader import AttributesImpl
try:
intern = sys.intern
except AttributeError:
Expand All @@ -12,13 +15,13 @@

class PercentageFile(object):

def __init__(self, filename):
def __init__(self, filename: str) -> None:
self.size = os.stat(filename)[6]
self.delivered = 0
self.f = open(filename)
self.percentages = [1000] + [100 - 10 * x for x in range(0, 11)]

def read(self, size=None):
def read(self, size: Optional[int] = None) -> str:
if size is None:
self.delivered = self.size
return self.f.read()
Expand All @@ -34,21 +37,21 @@ def read(self, size=None):
self.percentages = self.percentages[:-1]
return data

def close(self):
def close(self) -> None:
self.f.close()

@property
def percentage(self):
def percentage(self) -> float:
return float(self.delivered) / self.size * 100.0


class NodeHandler(xml.sax.ContentHandler):

def __init__(self, found_nodes):
def __init__(self, found_nodes: Set[int]) -> None:
self.found_nodes = found_nodes
self.nodes = {}

def startElement(self, tag, attributes):
def startElement(self, tag: str, attributes: AttributesImpl) -> None:
if tag == "node":
osm_id = int(attributes["id"])
if osm_id not in self.found_nodes:
Expand All @@ -59,7 +62,7 @@ def startElement(self, tag, attributes):

class WayHandler(xml.sax.ContentHandler):

def __init__(self, parser_helper):
def __init__(self, parser_helper: WayParserHelper) -> None:
# stores all found ways
self.found_ways = []
self.found_nodes = set()
Expand All @@ -69,7 +72,7 @@ def __init__(self, parser_helper):

self.parser_helper = parser_helper

def startElement(self, tag, attributes):
def startElement(self, tag: str, attributes: AttributesImpl) -> None:
if tag == "way":
self.start_tag_found = True
self.current_way = OSMWay(int(attributes["id"]))
Expand Down Expand Up @@ -110,7 +113,7 @@ def startElement(self, tag, attributes):
e = sys.exc_info()[0]
print("Error while parsing: {}".format(e))

def endElement(self, tag):
def endElement(self, tag: str) -> None:
if tag == "way":
self.start_tag_found = False

Expand Down
2 changes: 1 addition & 1 deletion utils/geo_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


# from http://www.johndcook.com/blog/python_longitude_latitude/
def distance(lat1, lon1, lat2, lon2):
def distance(lat1: float, lon1: float, lat2: float, lon2: float) -> float:

# Convert latitude and longitude to
# spherical coordinates in radians.
Expand Down
3 changes: 2 additions & 1 deletion utils/timer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import time


def timer(function):
from typing import Callable
def timer(function: Callable) -> Callable:
def wrapper(*args, **kwargs):
start_time = time.time()
print("starting {}".format(function.__name__))
Expand Down

0 comments on commit b84f364

Please sign in to comment.