From b84f364c2571d54157ba3080a3cb87bef4e30e9d Mon Sep 17 00:00:00 2001 From: Andreas Date: Sat, 14 Dec 2019 21:08:51 +0100 Subject: [PATCH] Types (#62) * type information added --- graph/algorithms.py | 6 ++++-- graph/contract_graph.py | 21 ++++++++++++--------- graph/graph.py | 12 +++++++----- graph/graph_types.py | 18 +++++++++--------- osm/osm_types.py | 10 ++++++---- osm/xml_handler.py | 21 ++++++++++++--------- utils/geo_tools.py | 2 +- utils/timer.py | 3 ++- 8 files changed, 53 insertions(+), 40 deletions(-) diff --git a/graph/algorithms.py b/graph/algorithms.py index ec73896..b8d1619 100644 --- a/graph/algorithms.py +++ b/graph/algorithms.py @@ -4,7 +4,9 @@ import utils.timer as timer -def BFS(graph, s): +from graph.graph import Graph +from typing import Set +def BFS(graph: Graph, s: int) -> Set[int]: seen_nodes = set([s]) unvisited_nodes = deque([s]) @@ -35,7 +37,7 @@ def computeLCC(graph): return lcc -def computeLCCGraph(graph): +def computeLCCGraph(graph: Graph) -> Graph: lcc = computeLCC(graph) new_nodes = [graph.vertices[id] for id in lcc] return graphfactory.build_graph_from_vertices_edges(new_nodes, graph.edges) diff --git a/graph/contract_graph.py b/graph/contract_graph.py index ba0d000..72d502b 100644 --- a/graph/contract_graph.py +++ b/graph/contract_graph.py @@ -5,6 +5,9 @@ import utils.timer as timer +from graph.graph import Graph +from graph.graph_types import Edge, SimpleEdge, Vertex +from typing import Dict, List, Optional, Set, Tuple @timer.timer def contract(graph): all_new_edges = find_new_edges(graph) @@ -15,11 +18,11 @@ def contract(graph): return graphfactory.build_graph_from_vertices_edges(nodes, filtered_edges) -def get_nodes(graph, node_ids): +def get_nodes(graph: Graph, node_ids: Set[int]) -> List[Vertex]: return list(map(lambda node_id: graph.get_node(node_id), node_ids)) -def gather_node_ids(edges): +def gather_node_ids(edges: List[SimpleEdge]) -> Set[int]: print("\t gathering nodes...") node_ids = set() for e in edges: @@ -28,7 +31,7 @@ def gather_node_ids(edges): return node_ids -def remove_duplicates(edges): +def remove_duplicates(edges: List[SimpleEdge]) -> List[SimpleEdge]: print("\t removing duplicate edges...") added_edges = set() filtered_edges = [] @@ -40,7 +43,7 @@ def remove_duplicates(edges): return filtered_edges -def find_new_edges(graph): +def find_new_edges(graph: Graph) -> List[SimpleEdge]: edge_by_s_t = edge_mapping(graph) all_new_edges = [] for node_id in range(len(graph.vertices)): @@ -56,7 +59,7 @@ def find_new_edges(graph): return all_new_edges -def edge_mapping(graph): +def edge_mapping(graph: Graph) -> Dict[Tuple[int, int], Edge]: edge_by_s_t = {} for edge in graph.edges: edge_by_s_t[(edge.s, edge.t)] = edge @@ -69,11 +72,11 @@ def edge_mapping(graph): return edge_by_s_t -def is_important_node(graph, node_id): +def is_important_node(graph: Graph, node_id: int) -> bool: return len(graph.all_neighbors(node_id)) != 2 -def get_edges(nodes, edges_by_s_t): +def get_edges(nodes: List[int], edges_by_s_t: Dict[Tuple[int, int], Edge]) -> List[Edge]: if nodes: edges = [] for i in range(len(nodes) - 1): @@ -83,7 +86,7 @@ def get_edges(nodes, edges_by_s_t): return [] -def merge_edges(edges): +def merge_edges(edges: List[Edge]) -> Optional[SimpleEdge]: if edges: s, t = edges[0].s, edges[-1].t if s != t: @@ -92,7 +95,7 @@ def merge_edges(edges): return None -def nodes_to_next_important_node(graph, start_node, next_node): +def nodes_to_next_important_node(graph: Graph, start_node: int, next_node: int) -> List[int]: if start_node == next_node: print("\t something is wrong here...") return [] diff --git a/graph/graph.py b/graph/graph.py index daacbe6..2846dc0 100644 --- a/graph/graph.py +++ b/graph/graph.py @@ -1,13 +1,15 @@ +from graph.graph_types import Edge, SimpleEdge, Vertex +from typing import List, Union class Graph(object): - def __init__(self): + def __init__(self) -> None: self.edges = [] self.vertices = [] self.outneighbors = [] self.inneighbors = [] - def add_edge(self, edge): + def add_edge(self, edge: Union[SimpleEdge, Edge]) -> None: self.edges.append(edge) if edge.forward: @@ -18,12 +20,12 @@ def add_edge(self, edge): self.outneighbors[edge.t].add(edge.s) self.inneighbors[edge.s].add(edge.t) - def add_node(self, vertex): + def add_node(self, vertex: Vertex) -> None: self.vertices.append(vertex) self.outneighbors.append(set()) self.inneighbors.append(set()) - def get_node(self, node_id): + def get_node(self, node_id: int) -> Vertex: return self.vertices[node_id] def edge_description(self, edge_id): @@ -32,5 +34,5 @@ def edge_description(self, edge_id): def edge_name(self, edge_id): return "{}".format(self.edges[edge_id].name) - def all_neighbors(self, node_id): + def all_neighbors(self, node_id: int) -> List[int]: return list(self.outneighbors[node_id].union(self.inneighbors[node_id])) diff --git a/graph/graph_types.py b/graph/graph_types.py index 344ab9c..4ef8937 100644 --- a/graph/graph_types.py +++ b/graph/graph_types.py @@ -1,19 +1,19 @@ class Vertex(object): __slots__ = ["id", "lat", "lon"] - def __init__(self, id, lat, lon): + def __init__(self, id: int, lat: float, lon: float) -> None: self.id = id self.lat, self.lon = lat, lon @property - def description(self): + def description(self) -> str: return "{} {} {}".format(self.id, self.lat, self.lon) class Edge(object): __slots__ = ["s", "t", "length", "highway", "max_v", "forward", "backward", "name"] - def __init__(self, s, t, length, highway, max_v, f, b, name): + def __init__(self, s: int, t: int, length: float, highway: str, max_v: int, f: bool, b: bool, name: str) -> None: self.s, self.t = s, t self.length = length self.highway = highway @@ -22,7 +22,7 @@ def __init__(self, s, t, length, highway, max_v, f, b, name): self.name = name @property - def description(self): + def description(self) -> str: both_directions = "1" if self.forward and self.backward else "0" return "{} {} {} {} {} {}".format(self.s, self.t, self.length, self.highway, self.max_v, both_directions) @@ -30,19 +30,19 @@ def description(self): class SimpleEdge(object): __slots__ = ["s", "t", "length", "name"] - def __init__(self, s, t, length): + def __init__(self, s: int, t: int, length: float) -> None: self.s, self.t = s, t self.length = length self.name = "" @property - def forward(self): + def forward(self) -> bool: return True @property - def backward(self): + def backward(self) -> bool: return True @property - def description(self): - return "{} {} {}".format(self.s, self.t, self.length) \ No newline at end of file + def description(self) -> str: + return "{} {} {}".format(self.s, self.t, self.length) diff --git a/osm/osm_types.py b/osm/osm_types.py index b70e555..b8b86a1 100644 --- a/osm/osm_types.py +++ b/osm/osm_types.py @@ -1,8 +1,10 @@ # -*- coding: utf-8 -*- +from typing import Optional + class OSMWay: __slots__ = ["osm_id", "nodes", "highway", "area", "max_speed", "direction", "forward", "backward", "name"] - def __init__(self, osm_id): + def __init__(self, osm_id: int) -> None: self.osm_id = osm_id self.nodes = [] self.highway = None @@ -13,13 +15,13 @@ def __init__(self, osm_id): self.backward = True self.name = "" - def add_node(self, osm_id): + def add_node(self, osm_id: int) -> None: self.nodes.append(osm_id) class OSMNode(object): __slots__ = ["lat", "lon", "osm_id"] - def __init__(self, osm_id, lat=None, lon=None): + def __init__(self, osm_id: int, lat: Optional[float] = None, lon: Optional[float] = None) -> None: self.osm_id = osm_id - self.lat, self.lon = lat, lon \ No newline at end of file + self.lat, self.lon = lat, lon diff --git a/osm/xml_handler.py b/osm/xml_handler.py index 6e8c2f4..a95b0c8 100644 --- a/osm/xml_handler.py +++ b/osm/xml_handler.py @@ -4,6 +4,9 @@ from osm.osm_types import OSMWay, OSMNode +from osm.way_parser_helper import WayParserHelper +from typing import Optional, Set +from xml.sax.xmlreader import AttributesImpl try: intern = sys.intern except AttributeError: @@ -12,13 +15,13 @@ class PercentageFile(object): - def __init__(self, filename): + def __init__(self, filename: str) -> None: self.size = os.stat(filename)[6] self.delivered = 0 self.f = open(filename) self.percentages = [1000] + [100 - 10 * x for x in range(0, 11)] - def read(self, size=None): + def read(self, size: Optional[int] = None) -> str: if size is None: self.delivered = self.size return self.f.read() @@ -34,21 +37,21 @@ def read(self, size=None): self.percentages = self.percentages[:-1] return data - def close(self): + def close(self) -> None: self.f.close() @property - def percentage(self): + def percentage(self) -> float: return float(self.delivered) / self.size * 100.0 class NodeHandler(xml.sax.ContentHandler): - def __init__(self, found_nodes): + def __init__(self, found_nodes: Set[int]) -> None: self.found_nodes = found_nodes self.nodes = {} - def startElement(self, tag, attributes): + def startElement(self, tag: str, attributes: AttributesImpl) -> None: if tag == "node": osm_id = int(attributes["id"]) if osm_id not in self.found_nodes: @@ -59,7 +62,7 @@ def startElement(self, tag, attributes): class WayHandler(xml.sax.ContentHandler): - def __init__(self, parser_helper): + def __init__(self, parser_helper: WayParserHelper) -> None: # stores all found ways self.found_ways = [] self.found_nodes = set() @@ -69,7 +72,7 @@ def __init__(self, parser_helper): self.parser_helper = parser_helper - def startElement(self, tag, attributes): + def startElement(self, tag: str, attributes: AttributesImpl) -> None: if tag == "way": self.start_tag_found = True self.current_way = OSMWay(int(attributes["id"])) @@ -110,7 +113,7 @@ def startElement(self, tag, attributes): e = sys.exc_info()[0] print("Error while parsing: {}".format(e)) - def endElement(self, tag): + def endElement(self, tag: str) -> None: if tag == "way": self.start_tag_found = False diff --git a/utils/geo_tools.py b/utils/geo_tools.py index aee2ffd..bedb4d2 100644 --- a/utils/geo_tools.py +++ b/utils/geo_tools.py @@ -2,7 +2,7 @@ # from http://www.johndcook.com/blog/python_longitude_latitude/ -def distance(lat1, lon1, lat2, lon2): +def distance(lat1: float, lon1: float, lat2: float, lon2: float) -> float: # Convert latitude and longitude to # spherical coordinates in radians. diff --git a/utils/timer.py b/utils/timer.py index 710707a..b43ea59 100644 --- a/utils/timer.py +++ b/utils/timer.py @@ -1,7 +1,8 @@ import time -def timer(function): +from typing import Callable +def timer(function: Callable) -> Callable: def wrapper(*args, **kwargs): start_time = time.time() print("starting {}".format(function.__name__))