Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make providers configurable, add config file for easier configuration #15

Merged
merged 13 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ Optional arguments:
It is enforced on per-second basis, to avoid bursts.
- `--traveltime-max-rpm [int]`: Set max number of parallel requests sent to TravelTime API per minute. Default is 60.
It is enforced on per-second basis, to avoid bursts.
- `--providers [providers]`: The providers you want to compare to TravelTime (e.g., --providers google mapbox).
Possible options: google, tomtom, here, mapbox, osrm, openroutes. TravelTime is included regardless of input.
By default, all possible providers are included.

Example:

Expand Down
12 changes: 9 additions & 3 deletions src/traveltime_google_comparison/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,18 @@ def format_results_for_csv(
def run_analysis(
    results: DataFrame, output_file: str, quantile: float, api_providers: List[str]
):
    """Compare the selected providers against TravelTime and write a CSV report.

    ``TRAVELTIME_API`` is removed from ``api_providers`` before any comparison
    is made, so TravelTime is never compared against itself. The differences
    are calculated, logged and formatted exactly once.

    Args:
        results: DataFrame of travel times handed to ``calculate_differences``.
        output_file: Path of the CSV file the formatted results are written to.
        quantile: Quantile forwarded to ``log_results``.
        api_providers: Provider identifiers; may or may not contain
            ``TRAVELTIME_API``.
    """
    providers_without_traveltime = [p for p in api_providers if p != TRAVELTIME_API]

    results_with_differences = calculate_differences(
        results, providers_without_traveltime
    )
    log_results(results_with_differences, quantile, providers_without_traveltime)

    logging.info(f"Detailed results can be found in {output_file} file")

    formatted_results = format_results_for_csv(
        results_with_differences, providers_without_traveltime
    )

    formatted_results.to_csv(output_file, index=False)

Expand Down
26 changes: 12 additions & 14 deletions src/traveltime_google_comparison/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@
TRAVELTIME_API = "traveltime"
OPENROUTES_API = "openroutes"

# Every provider identifier known to this module, TravelTime included.
# NOTE(review): callers appear to filter this list against the user's
# selection, so the order here likely decides the resulting provider
# order - confirm before reordering.
ALL_PROVIDERS = [
GOOGLE_API,
TOMTOM_API,
HERE_API,
MAPBOX_API,
OSRM_API,
OPENROUTES_API,
TRAVELTIME_API,
]


def get_capitalized_provider_name(provider: str) -> str:
if provider == "google":
Expand Down Expand Up @@ -146,26 +156,14 @@ async def collect_travel_times(
capitalized_providers_str = ", ".join(
[get_capitalized_provider_name(provider) for provider in providers]
)
logger.info(
f"Sending {len(tasks)} requests to {capitalized_providers_str} and TravelTime APIs"
)
logger.info(f"Sending {len(tasks)} requests to {capitalized_providers_str} APIs")

results = await asyncio.gather(*tasks)

results_df = pd.DataFrame(results)
deduplicated = results_df.groupby(
[Fields.ORIGIN, Fields.DESTINATION, Fields.DEPARTURE_TIME], as_index=False
).agg(
{
Fields.TRAVEL_TIME[GOOGLE_API]: "first",
Fields.TRAVEL_TIME[TOMTOM_API]: "first",
Fields.TRAVEL_TIME[HERE_API]: "first",
Fields.TRAVEL_TIME[OSRM_API]: "first",
Fields.TRAVEL_TIME[OPENROUTES_API]: "first",
Fields.TRAVEL_TIME[MAPBOX_API]: "first",
Fields.TRAVEL_TIME[TRAVELTIME_API]: "first",
}
)
).agg({Fields.TRAVEL_TIME[provider]: "first" for provider in providers})
deduplicated.to_csv(args.output, index=False)
return deduplicated

Expand Down
18 changes: 17 additions & 1 deletion src/traveltime_google_comparison/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class Mode(Enum):

def parse_args():
parser = argparse.ArgumentParser(
description="Fetch and compare travel times from Google Directions API and TravelTime Routes API"
description="Fetch and compare travel times from TravelTime Routes API and it's competitors"
)
parser.add_argument("--input", required=True, help="Input CSV file path")
parser.add_argument("--output", required=True, help="Output CSV file path")
Expand All @@ -50,6 +50,22 @@ def parse_args():
required=True,
help="Non-abbreviated time zone identifier e.g. Europe/London",
)
parser.add_argument(
"--providers",
nargs="+",
default=[
"google",
"tomtom",
"here",
"mapbox",
"osrm",
"openroutes",
"traveltime",
],
help="""List of providers to use and compare against TravelTime (e.g., --providers google mapbox).
Possible options: google, tomtom, here, mapbox, osrm, openroutes.
TravelTime is included regardless of input.""",
)
parser.add_argument(
"--google-max-rpm",
required=False,
Expand Down
48 changes: 17 additions & 31 deletions src/traveltime_google_comparison/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,9 @@
from traveltime_google_comparison import config
from traveltime_google_comparison.analysis import run_analysis
from traveltime_google_comparison.collect import (
OSRM_API,
OPENROUTES_API,
HERE_API,
MAPBOX_API,
Fields,
GOOGLE_API,
TRAVELTIME_API,
TOMTOM_API,
ALL_PROVIDERS,
)
from traveltime_google_comparison.requests import factory

Expand All @@ -28,8 +23,16 @@


async def run():
providers = [GOOGLE_API, TOMTOM_API, HERE_API, MAPBOX_API, OSRM_API, OPENROUTES_API]
args = config.parse_args()

# Get all providers that should be tested against TravelTime
providers = [provider for provider in ALL_PROVIDERS if provider in args.providers]

# TravelTime always should be in the analysis, unless in the future we decide to
# allow the user to control what is the base for comparison.
if TRAVELTIME_API not in providers:
providers.append(TRAVELTIME_API)

csv = pd.read_csv(
args.input, usecols=[Fields.ORIGIN, Fields.DESTINATION]
).drop_duplicates()
Expand All @@ -38,43 +41,26 @@ async def run():
logger.info("Provided input file is empty. Exiting.")
return

request_handlers = factory.initialize_request_handlers(
args.google_max_rpm,
args.tomtom_max_rpm,
args.here_max_rpm,
args.osrm_max_rpm,
args.openroutes_max_rpm,
args.mapbox_max_rpm,
args.traveltime_max_rpm,
)
request_handlers = factory.initialize_request_handlers(providers, args)
if args.skip_data_gathering:
travel_times_df = pd.read_csv(
args.input,
usecols=[
Fields.ORIGIN,
Fields.DESTINATION,
Fields.DEPARTURE_TIME,
Fields.TRAVEL_TIME[GOOGLE_API],
Fields.TRAVEL_TIME[TOMTOM_API],
Fields.TRAVEL_TIME[HERE_API],
Fields.TRAVEL_TIME[OSRM_API],
Fields.TRAVEL_TIME[OPENROUTES_API],
Fields.TRAVEL_TIME[MAPBOX_API],
Fields.TRAVEL_TIME[TRAVELTIME_API],
],
] # base fields
+ [Fields.TRAVEL_TIME[provider] for provider in providers], # all providers
)
else:
travel_times_df = await collect.collect_travel_times(
args, csv, request_handlers, providers
)

filtered_travel_times_df = travel_times_df.loc[
travel_times_df[Fields.TRAVEL_TIME[GOOGLE_API]].notna()
& travel_times_df[Fields.TRAVEL_TIME[TOMTOM_API]].notna()
& travel_times_df[Fields.TRAVEL_TIME[HERE_API]].notna()
& travel_times_df[Fields.TRAVEL_TIME[OSRM_API]].notna()
& travel_times_df[Fields.TRAVEL_TIME[OPENROUTES_API]].notna()
& travel_times_df[Fields.TRAVEL_TIME[MAPBOX_API]].notna()
& travel_times_df[Fields.TRAVEL_TIME[TRAVELTIME_API]].notna(),
travel_times_df[[Fields.TRAVEL_TIME[provider] for provider in providers]]
.notna()
.all(axis=1),
:,
]

Expand Down
72 changes: 46 additions & 26 deletions src/traveltime_google_comparison/requests/factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Dict
from argparse import Namespace
from typing import Dict, List

from traveltime_google_comparison.collect import (
TOMTOM_API,
Expand Down Expand Up @@ -32,30 +33,49 @@


def initialize_request_handlers(
    providers: List[str], args: Namespace
) -> Dict[str, BaseRequestHandler]:
    """Create request handlers for the selected providers plus TravelTime.

    Handlers are built lazily: API keys are only retrieved for providers
    that were actually requested, so a missing key for an unselected
    provider is never looked up.

    Args:
        providers: Provider identifiers to create handlers for. Identifiers
            without a registered factory (including ``TRAVELTIME_API``,
            which is handled separately) are skipped.
        args: Parsed command-line arguments; the ``<provider>_max_rpm``
            attributes are read for every handler that gets created.

    Returns:
        Mapping from provider identifier to its request handler. The
        ``TRAVELTIME_API`` handler is always present and is added last.
    """
    # Factories instead of instances: nothing is retrieved or constructed
    # until a provider is known to be selected.
    handler_factories = {
        GOOGLE_API: lambda: GoogleRequestHandler(
            retrieve_google_api_key(), args.google_max_rpm
        ),
        TOMTOM_API: lambda: TomTomRequestHandler(
            retrieve_tomtom_api_key(), args.tomtom_max_rpm
        ),
        HERE_API: lambda: HereRequestHandler(
            retrieve_here_api_key(), args.here_max_rpm
        ),
        OSRM_API: lambda: OSRMRequestHandler("", args.osrm_max_rpm),
        OPENROUTES_API: lambda: OpenRoutesRequestHandler(
            retrieve_openroutes_api_key(), args.openroutes_max_rpm
        ),
        MAPBOX_API: lambda: MapboxRequestHandler(
            retrieve_mapbox_api_key(), args.mapbox_max_rpm
        ),
    }

    handlers = {
        provider: handler_factories[provider]()
        for provider in providers
        if provider in handler_factories
    }

    # TravelTime is always added, whether or not it was listed in `providers`.
    credentials = retrieve_traveltime_credentials()
    handlers[TRAVELTIME_API] = TravelTimeRequestHandler(
        credentials.app_id, credentials.api_key, args.traveltime_max_rpm
    )

    return handlers
Loading