diff --git a/README.md b/README.md index b0921c3..d1142ce 100644 --- a/README.md +++ b/README.md @@ -37,44 +37,26 @@ pip install traveltime-google-comparison ``` ## Setup -Provide credentials for the APIs via environment variables. - -For Google Maps API: - -```bash -export GOOGLE_API_KEY=[Your Google Maps API Key] -``` - -For TomTom API: - -```bash -export TOMTOM_API_KEY=[Your TomTom API Key] -``` - -For HERE API: - -```bash -export HERE_API_KEY=[Your HERE API Key] -``` - -For Mapbox API: - -```bash -export MAPBOX_API_KEY=[Your Mapbox API Key] -``` - -For OpenRoutes API: - -```bash -export OPENROUTES_API_KEY=[Your OpenRoutes API Key] -``` - -For OSRM API: OSRM does not require a key. - -For TravelTime API: -```bash -export TRAVELTIME_APP_ID=[Your TravelTime App ID] -export TRAVELTIME_API_KEY=[Your TravelTime API Key] +Provide credentials and desired max requests per minute for the APIs inside the `config.json` file. +You can also disable unwanted APIs by changing the `enabled` value to `false`. + +```json +{ + "traveltime": { + "app-id": "", + "api-key": "", + "max-rpm": "60" + }, + "api-providers": [ + { + "name": "google", + "enabled": true, + "api-key": "", + "max-rpm": "60" + }, + ...other providers + ] +} ``` ## Usage @@ -104,23 +86,8 @@ Required arguments: - `--time-zone-id [Time zone ID]`: non-abbreviated time zone identifier in which the time values are specified. For example: `Europe/London`. For more information, see [here](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones). - - Optional arguments: -- `--google-max-rpm [int]`: Set max number of parallel requests sent to Google API per minute. Default is 60. - It is enforced on per-second basis, to avoid bursts. -- `--tomtom-max-rpm [int]`: Set max number of parallel requests sent to TomTom API per minute. Default is 60. - It is enforced on per-second basis, to avoid bursts. -- `--mapbox-max-rpm [int]`: Set max number of parallel requests sent to Mapbox API per minute. Default is 60. - It is enforced on per-second basis, to avoid bursts. -- `--here-max-rpm [int]`: Set max number of parallel requests sent to HERE API per minute. Default is 60. - It is enforced on per-second basis, to avoid bursts. -- `--osrm-max-rpm [int]`: Set max number of parallel requests sent to OSRM API per minute. Default is 60. - It is enforced on per-second basis, to avoid bursts. -- `--openroutes-max-rpm [int]`: Set max number of parallel requests sent to OpenRoutes API per minute. Default is 60. - It is enforced on per-second basis, to avoid bursts. -- `--traveltime-max-rpm [int]`: Set max number of parallel requests sent to TravelTime API per minute. Default is 60. - It is enforced on per-second basis, to avoid bursts. +- `--config [Config file path]`: Path to the config file. Default - ./config.json Example: diff --git a/config.json b/config.json new file mode 100644 index 0000000..159a429 --- /dev/null +++ b/config.json @@ -0,0 +1,45 @@ +{ + "traveltime": { + "app-id": "", + "api-key": "", + "max-rpm": "60" + }, + "api-providers": [ + { + "name": "google", + "enabled": true, + "api-key": "", + "max-rpm": "60" + }, + { + "name": "tomtom", + "enabled": true, + "api-key": "", + "max-rpm": "60" + }, + { + "name": "here", + "enabled": true, + "api-key": "", + "max-rpm": "60" + }, + { + "name": "mapbox", + "enabled": true, + "api-key": "", + "max-rpm": "60" + }, + { + "name": "osrm", + "enabled": true, + "api-key": "not-needed!", + "max-rpm": "60" + }, + { + "name": "openroutes", + "enabled": true, + "api-key": "", + "max-rpm": "20" + } + ] +} diff --git a/src/traveltime_google_comparison/analysis.py b/src/traveltime_google_comparison/analysis.py index db234b0..023e1e1 100644 --- a/src/traveltime_google_comparison/analysis.py +++ b/src/traveltime_google_comparison/analysis.py @@ -1,6 +1,5 @@ import logging from dataclasses import dataclass -from typing import List from pandas import DataFrame @@ -9,6 +8,7 @@ TRAVELTIME_API, get_capitalized_provider_name, ) +from traveltime_google_comparison.config import Providers def absolute_error(api_provider: str) -> str: @@ -26,17 +26,16 @@ class QuantileErrorResult: def log_results( - results_with_differences: DataFrame, quantile: float, api_providers: List[str] + results_with_differences: DataFrame, quantile: float, api_providers: Providers ): - for provider in api_providers: - capitalized_provider = get_capitalized_provider_name(provider) + for provider in api_providers.competitors: + name = provider.name + capitalized_provider = get_capitalized_provider_name(name) logging.info( f"Mean relative error compared to {capitalized_provider} " - f"API: {results_with_differences[relative_error(provider)].mean():.2f}%" - ) - quantile_errors = calculate_quantiles( - results_with_differences, quantile, provider + f"API: {results_with_differences[relative_error(name)].mean():.2f}%" ) + quantile_errors = calculate_quantiles(results_with_differences, quantile, name) logging.info( f"{int(quantile * 100)}% of TravelTime results differ from {capitalized_provider} API " f"by less than {int(quantile_errors.relative_error)}%" @@ -44,13 +43,14 @@ def log_results( def format_results_for_csv( - results_with_differences: DataFrame, api_providers: List[str] + results_with_differences: DataFrame, api_providers: Providers ) -> DataFrame: formatted_results = results_with_differences.copy() - for provider in api_providers: - formatted_results = formatted_results.drop(columns=[absolute_error(provider)]) - relative_error_col = relative_error(provider) + for provider in api_providers.competitors: + name = provider.name + formatted_results = formatted_results.drop(columns=[absolute_error(name)]) + relative_error_col = relative_error(name) formatted_results[relative_error_col] = formatted_results[ relative_error_col ].astype(int) @@ -59,7 +59,7 @@ def format_results_for_csv( def run_analysis( - results: DataFrame, output_file: str, quantile: float, api_providers: List[str] + results: DataFrame, output_file: str, quantile: float, api_providers: Providers ): results_with_differences = calculate_differences(results, api_providers) log_results(results_with_differences, quantile, api_providers) @@ -71,21 +71,22 @@ def run_analysis( formatted_results.to_csv(output_file, index=False) -def calculate_differences(results: DataFrame, api_providers: List[str]) -> DataFrame: +def calculate_differences(results: DataFrame, api_providers: Providers) -> DataFrame: results_with_differences = results.copy() - for provider in api_providers: - absolute_error_col = absolute_error(provider) - relative_error_col = relative_error(provider) + for provider in api_providers.competitors: + name = provider.name + absolute_error_col = absolute_error(name) + relative_error_col = relative_error(name) results_with_differences[absolute_error_col] = abs( - results[Fields.TRAVEL_TIME[provider]] + results[Fields.TRAVEL_TIME[name]] - results[Fields.TRAVEL_TIME[TRAVELTIME_API]] ) results_with_differences[relative_error_col] = ( results_with_differences[absolute_error_col] - / results_with_differences[Fields.TRAVEL_TIME[provider]] + / results_with_differences[Fields.TRAVEL_TIME[name]] * 100 ) @@ -95,13 +96,13 @@ def calculate_differences(results: DataFrame, api_providers: List[str]) -> DataF def calculate_quantiles( results_with_differences: DataFrame, quantile: float, - api_provider: str, + api_provider_name: str, ) -> QuantileErrorResult: quantile_absolute_error = results_with_differences[ - absolute_error(api_provider) + absolute_error(api_provider_name) ].quantile(quantile, "higher") quantile_relative_error = results_with_differences[ - relative_error(api_provider) + relative_error(api_provider_name) ].quantile(quantile, "higher") return QuantileErrorResult( int(quantile_absolute_error), int(quantile_relative_error) diff --git a/src/traveltime_google_comparison/collect.py b/src/traveltime_google_comparison/collect.py index 2fc746e..9ec0e08 100644 --- a/src/traveltime_google_comparison/collect.py +++ b/src/traveltime_google_comparison/collect.py @@ -13,6 +13,7 @@ from traveltime_google_comparison.config import Mode from traveltime_google_comparison.requests.base_handler import BaseRequestHandler + GOOGLE_API = "google" TOMTOM_API = "tomtom" HERE_API = "here" @@ -132,7 +133,10 @@ def generate_tasks( async def collect_travel_times( - args, data, request_handlers: Dict[str, BaseRequestHandler], providers: List[str] + args, + data, + request_handlers: Dict[str, BaseRequestHandler], + provider_names: List[str], ) -> DataFrame: timezone = pytz.timezone(args.time_zone_id) localized_start_datetime = localize_datetime(args.date, args.start_time, timezone) @@ -144,28 +148,16 @@ async def collect_travel_times( tasks = generate_tasks(data, time_instants, request_handlers, mode=Mode.DRIVING) capitalized_providers_str = ", ".join( - [get_capitalized_provider_name(provider) for provider in providers] - ) - logger.info( - f"Sending {len(tasks)} requests to {capitalized_providers_str} and TravelTime APIs" + [get_capitalized_provider_name(provider) for provider in provider_names] ) + logger.info(f"Sending {len(tasks)} requests to {capitalized_providers_str} APIs") results = await asyncio.gather(*tasks) results_df = pd.DataFrame(results) deduplicated = results_df.groupby( [Fields.ORIGIN, Fields.DESTINATION, Fields.DEPARTURE_TIME], as_index=False - ).agg( - { - Fields.TRAVEL_TIME[GOOGLE_API]: "first", - Fields.TRAVEL_TIME[TOMTOM_API]: "first", - Fields.TRAVEL_TIME[HERE_API]: "first", - Fields.TRAVEL_TIME[OSRM_API]: "first", - Fields.TRAVEL_TIME[OPENROUTES_API]: "first", - Fields.TRAVEL_TIME[MAPBOX_API]: "first", - Fields.TRAVEL_TIME[TRAVELTIME_API]: "first", - } - ) + ).agg({Fields.TRAVEL_TIME[provider]: "first" for provider in provider_names}) deduplicated.to_csv(args.output, index=False) return deduplicated diff --git a/src/traveltime_google_comparison/config.py b/src/traveltime_google_comparison/config.py index 1ba69d8..2c772f5 100644 --- a/src/traveltime_google_comparison/config.py +++ b/src/traveltime_google_comparison/config.py @@ -1,11 +1,13 @@ import argparse -import os +from dataclasses import dataclass from enum import Enum +from typing import List import pandas +from traveltimepy.http import json from traveltime_google_comparison.requests.traveltime_credentials import ( - TravelTimeCredentials, + Credentials, ) DEFAULT_GOOGLE_RPM = 60 @@ -28,6 +30,22 @@ pandas.set_option("display.width", None) +@dataclass +class Provider: + name: str + max_rpm: int + credentials: Credentials + + +@dataclass +class Providers: + base: Provider + competitors: List[Provider] + + def all_names(self) -> List[str]: + return [self.base.name] + [competitor.name for competitor in self.competitors] + + class Mode(Enum): DRIVING = "driving" PUBLIC_TRANSPORT = "public_transport" @@ -35,7 +53,7 @@ class Mode(Enum): def parse_args(): parser = argparse.ArgumentParser( - description="Fetch and compare travel times from Google Directions API and TravelTime Routes API" + description="Fetch and compare travel times from TravelTime Routes API and it's competitors" ) parser.add_argument("--input", required=True, help="Input CSV file path") parser.add_argument("--output", required=True, help="Output CSV file path") @@ -51,55 +69,11 @@ def parse_args(): help="Non-abbreviated time zone identifier e.g. Europe/London", ) parser.add_argument( - "--google-max-rpm", - required=False, - type=int, - default=DEFAULT_GOOGLE_RPM, - help="Maximum number of requests sent to Google API per minute", - ) - parser.add_argument( - "--tomtom-max-rpm", - required=False, - type=int, - default=DEFAULT_TOMTOM_RPM, - help="Maximum number of requests sent to TomTom API per minute", - ) - parser.add_argument( - "--here-max-rpm", - required=False, - type=int, - default=DEFAULT_HERE_RPM, - help="Maximum number of requests sent to HERE API per minute", - ) - parser.add_argument( - "--osrm-max-rpm", - required=False, - type=int, - default=DEFAULT_OSRM_RPM, - help="Maximum number of requests sent to OSRM API per minute", - ) - parser.add_argument( - "--openroutes-max-rpm", - required=False, - type=int, - default=DEFAULT_OPENROUTES_RPM, - help="Maximum number of requests sent to OpenRoutes API per minute", - ) - parser.add_argument( - "--mapbox-max-rpm", + "--config", required=False, - type=int, - default=DEFAULT_MAPBOX_RPM, - help="Maximum number of requests sent to Mapbox API per minute", + default="./config.json", + help="Path to your config file. Default - ./config.json", ) - parser.add_argument( - "--traveltime-max-rpm", - required=False, - type=int, - default=DEFAULT_TRAVELTIME_RPM, - help="Maximum number of requests sent to TravelTime API per minute", - ) - parser.add_argument( "--skip-data-gathering", action=argparse.BooleanOptionalAction, @@ -111,54 +85,40 @@ def parse_args(): return parser.parse_args() -def retrieve_google_api_key(): - google_api_key = os.environ.get(GOOGLE_API_KEY_VAR_NAME) - - if not google_api_key: - raise ValueError(f"{GOOGLE_API_KEY_VAR_NAME} not set in environment variables.") - return google_api_key +def parse_json_to_providers(json_data: str) -> Providers: + data = json.loads(json_data) + # Parse TravelTime (base provider) + traveltime_data = data["traveltime"] + base_provider = Provider( + name="traveltime", + max_rpm=int(traveltime_data["max-rpm"]), + credentials=Credentials( + app_id=traveltime_data["app-id"], api_key=traveltime_data["api-key"] + ), + ) -def retrieve_openroutes_api_key(): - openroutes_api_key = os.environ.get(OPENROUTES_API_KEY_VAR_NAME) - - if not openroutes_api_key: + # Parse competitor providers + competitors = [] + for provider_data in data["api-providers"]: + enabled = provider_data["enabled"] + if enabled: + competitor = Provider( + name=provider_data["name"], + max_rpm=int(provider_data["max-rpm"]), + credentials=Credentials(api_key=provider_data["api-key"]), + ) + competitors.append(competitor) + + if len(competitors) == 0: raise ValueError( - f"{OPENROUTES_API_KEY_VAR_NAME} not set in environment variables." + "There should be at least one enabled API provider that's not TravelTime." ) - return openroutes_api_key - - -def retrieve_tomtom_api_key(): - tomtom_api_key = os.environ.get(TOMTOM_API_KEY_VAR_NAME) - if not tomtom_api_key: - raise ValueError(f"{TOMTOM_API_KEY_VAR_NAME} not set in environment variables.") - return tomtom_api_key + return Providers(base=base_provider, competitors=competitors) -def retrieve_here_api_key(): - here_api_key = os.environ.get(HERE_API_KEY_VAR_NAME) - - if not here_api_key: - raise ValueError(f"{HERE_API_KEY_VAR_NAME} not set in environment variables.") - return here_api_key - - -def retrieve_mapbox_api_key(): - mapbox_api_key = os.environ.get(MAPBOX_API_KEY_VAR_NAME) - - if not mapbox_api_key: - raise ValueError(f"{MAPBOX_API_KEY_VAR_NAME} not set in environment variables.") - return mapbox_api_key - - -def retrieve_traveltime_credentials() -> TravelTimeCredentials: - app_id = os.environ.get(TRAVELTIME_APP_ID_VAR_NAME) - api_key = os.environ.get(TRAVELTIME_API_KEY_VAR_NAME) - - if not (app_id and api_key): - raise ValueError( - "TravelTime API credentials are missing from environment variables." - ) - return TravelTimeCredentials(app_id, api_key) +def parse_config(file_path: str): + with open(file_path, "r") as file: # letting it crash if this fails + content = file.read() + return parse_json_to_providers(content) diff --git a/src/traveltime_google_comparison/main.py b/src/traveltime_google_comparison/main.py index 4630e7d..5b5eb03 100644 --- a/src/traveltime_google_comparison/main.py +++ b/src/traveltime_google_comparison/main.py @@ -6,16 +6,8 @@ from traveltime_google_comparison import collect from traveltime_google_comparison import config from traveltime_google_comparison.analysis import run_analysis -from traveltime_google_comparison.collect import ( - OSRM_API, - OPENROUTES_API, - HERE_API, - MAPBOX_API, - Fields, - GOOGLE_API, - TRAVELTIME_API, - TOMTOM_API, -) +from traveltime_google_comparison.config import parse_config +from traveltime_google_comparison.collect import Fields from traveltime_google_comparison.requests import factory logging.basicConfig( @@ -28,8 +20,13 @@ async def run(): - providers = [GOOGLE_API, TOMTOM_API, HERE_API, MAPBOX_API, OSRM_API, OPENROUTES_API] args = config.parse_args() + config_path = args.config + + # Get all providers that should be tested against TravelTime + providers = parse_config(config_path) + all_provider_names = providers.all_names() + csv = pd.read_csv( args.input, usecols=[Fields.ORIGIN, Fields.DESTINATION] ).drop_duplicates() @@ -38,15 +35,7 @@ async def run(): logger.info("Provided input file is empty. Exiting.") return - request_handlers = factory.initialize_request_handlers( - args.google_max_rpm, - args.tomtom_max_rpm, - args.here_max_rpm, - args.osrm_max_rpm, - args.openroutes_max_rpm, - args.mapbox_max_rpm, - args.traveltime_max_rpm, - ) + request_handlers = factory.initialize_request_handlers(providers) if args.skip_data_gathering: travel_times_df = pd.read_csv( args.input, @@ -54,27 +43,20 @@ async def run(): Fields.ORIGIN, Fields.DESTINATION, Fields.DEPARTURE_TIME, - Fields.TRAVEL_TIME[GOOGLE_API], - Fields.TRAVEL_TIME[TOMTOM_API], - Fields.TRAVEL_TIME[HERE_API], - Fields.TRAVEL_TIME[OSRM_API], - Fields.TRAVEL_TIME[OPENROUTES_API], - Fields.TRAVEL_TIME[MAPBOX_API], - Fields.TRAVEL_TIME[TRAVELTIME_API], - ], + ] # base fields + + [Fields.TRAVEL_TIME[provider] for provider in all_provider_names], ) else: travel_times_df = await collect.collect_travel_times( - args, csv, request_handlers, providers + args, csv, request_handlers, all_provider_names ) + filtered_travel_times_df = travel_times_df.loc[ - travel_times_df[Fields.TRAVEL_TIME[GOOGLE_API]].notna() - & travel_times_df[Fields.TRAVEL_TIME[TOMTOM_API]].notna() - & travel_times_df[Fields.TRAVEL_TIME[HERE_API]].notna() - & travel_times_df[Fields.TRAVEL_TIME[OSRM_API]].notna() - & travel_times_df[Fields.TRAVEL_TIME[OPENROUTES_API]].notna() - & travel_times_df[Fields.TRAVEL_TIME[MAPBOX_API]].notna() - & travel_times_df[Fields.TRAVEL_TIME[TRAVELTIME_API]].notna(), + travel_times_df[ + [Fields.TRAVEL_TIME[provider] for provider in all_provider_names] + ] + .notna() + .all(axis=1), :, ] diff --git a/src/traveltime_google_comparison/requests/factory.py b/src/traveltime_google_comparison/requests/factory.py index 92c95e1..0a8b3ed 100644 --- a/src/traveltime_google_comparison/requests/factory.py +++ b/src/traveltime_google_comparison/requests/factory.py @@ -9,14 +9,7 @@ GOOGLE_API, OPENROUTES_API, ) -from traveltime_google_comparison.config import ( - retrieve_google_api_key, - retrieve_here_api_key, - retrieve_mapbox_api_key, - retrieve_tomtom_api_key, - retrieve_openroutes_api_key, - retrieve_traveltime_credentials, -) +from traveltime_google_comparison.config import Provider, Providers from traveltime_google_comparison.requests.base_handler import BaseRequestHandler from traveltime_google_comparison.requests.google_handler import GoogleRequestHandler from traveltime_google_comparison.requests.tomtom_handler import TomTomRequestHandler @@ -31,31 +24,45 @@ ) -def initialize_request_handlers( - google_max_rpm, - tomtom_max_rpm, - here_max_rpm, - osrm_max_rpm, - openroutes_max_rpm, - mapbox_max_rpm, - traveltime_max_rpm, -) -> Dict[str, BaseRequestHandler]: - google_api_key = retrieve_google_api_key() - tomtom_api_key = retrieve_tomtom_api_key() - here_api_key = retrieve_here_api_key() - mapbox_api_key = retrieve_mapbox_api_key() - openroutes_api_key = retrieve_openroutes_api_key() - credentials = retrieve_traveltime_credentials() - return { - GOOGLE_API: GoogleRequestHandler(google_api_key, google_max_rpm), - TOMTOM_API: TomTomRequestHandler(tomtom_api_key, tomtom_max_rpm), - HERE_API: HereRequestHandler(here_api_key, here_max_rpm), - OSRM_API: OSRMRequestHandler("", osrm_max_rpm), - OPENROUTES_API: OpenRoutesRequestHandler( - openroutes_api_key, openroutes_max_rpm - ), - MAPBOX_API: MapboxRequestHandler(mapbox_api_key, mapbox_max_rpm), - TRAVELTIME_API: TravelTimeRequestHandler( - credentials.app_id, credentials.api_key, traveltime_max_rpm - ), +def initialize_request_handlers(providers: Providers) -> Dict[str, BaseRequestHandler]: + def create_google_handler(provider: Provider): + return GoogleRequestHandler(provider.credentials.api_key, provider.max_rpm) + + def create_tomtom_handler(provider: Provider): + return TomTomRequestHandler(provider.credentials.api_key, provider.max_rpm) + + def create_here_handler(provider: Provider): + return HereRequestHandler(provider.credentials.api_key, provider.max_rpm) + + def create_osrm_handler(provider: Provider): + return OSRMRequestHandler("", provider.max_rpm) + + def create_openroutes_handler(provider: Provider): + return OpenRoutesRequestHandler(provider.credentials.api_key, provider.max_rpm) + + def create_mapbox_handler(provider: Provider): + return MapboxRequestHandler(provider.credentials.api_key, provider.max_rpm) + + def create_traveltime_handler(provider: Provider): + return TravelTimeRequestHandler( + provider.credentials.app_id, provider.credentials.api_key, provider.max_rpm + ) + + handler_mapping = { + GOOGLE_API: create_google_handler, + TOMTOM_API: create_tomtom_handler, + HERE_API: create_here_handler, + OSRM_API: create_osrm_handler, + OPENROUTES_API: create_openroutes_handler, + MAPBOX_API: create_mapbox_handler, } + + handlers = {} + for competitor in providers.competitors: + if competitor.name in handler_mapping: + handlers[competitor.name] = handler_mapping[competitor.name](competitor) + + # Always add TRAVELTIME_API handler + handlers[TRAVELTIME_API] = create_traveltime_handler(providers.base) + + return handlers diff --git a/src/traveltime_google_comparison/requests/traveltime_credentials.py b/src/traveltime_google_comparison/requests/traveltime_credentials.py index c333d4e..f43d1e8 100644 --- a/src/traveltime_google_comparison/requests/traveltime_credentials.py +++ b/src/traveltime_google_comparison/requests/traveltime_credentials.py @@ -1,7 +1,8 @@ from dataclasses import dataclass +from typing import Optional @dataclass -class TravelTimeCredentials: - app_id: str +class Credentials: api_key: str + app_id: Optional[str] = None diff --git a/test/test_analysis.py b/test/test_analysis.py index e9df99f..14939bb 100644 --- a/test/test_analysis.py +++ b/test/test_analysis.py @@ -7,10 +7,23 @@ relative_error, ) from traveltime_google_comparison.collect import GOOGLE_API, TRAVELTIME_API, Fields +from traveltime_google_comparison.config import Provider, Providers +from traveltime_google_comparison.requests.traveltime_credentials import ( + Credentials, +) ABSOLUTE_ERROR_GOOGLE = absolute_error(GOOGLE_API) RELATIVE_ERROR_GOOGLE = relative_error(GOOGLE_API) +PROVIDERS = Providers( + base=Provider( + name="traveltime", + max_rpm=60, + credentials=Credentials(app_id="test", api_key="test"), + ), + competitors=[Provider(name="google", max_rpm=60, credentials=Credentials("test"))], +) + def test_calculate_differences_calculate_absolute_and_relative_differences(): data = { @@ -18,7 +31,7 @@ def test_calculate_differences_calculate_absolute_and_relative_differences(): Fields.TRAVEL_TIME[TRAVELTIME_API]: [90, 210, 290], } df = pd.DataFrame(data) - result_df = calculate_differences(df, [GOOGLE_API]) + result_df = calculate_differences(df, PROVIDERS) assert result_df[ABSOLUTE_ERROR_GOOGLE].tolist() == [10, 10, 10] assert result_df[RELATIVE_ERROR_GOOGLE].tolist() == [10.0, 5.0, 10.0 / 3] @@ -30,7 +43,7 @@ def test_calculate_differences_survives_division_by_zero(): Fields.TRAVEL_TIME[TRAVELTIME_API]: [90, 210, 290], } df = pd.DataFrame(data) - result_df = calculate_differences(df, [GOOGLE_API]) + result_df = calculate_differences(df, PROVIDERS) assert result_df[ABSOLUTE_ERROR_GOOGLE].tolist() == [90, 10, 10] assert result_df[RELATIVE_ERROR_GOOGLE].tolist() == [float("inf"), 5.0, 10.0 / 3] diff --git a/test/test_config.py b/test/test_config.py index dda7274..828ad52 100644 --- a/test/test_config.py +++ b/test/test_config.py @@ -1,54 +1,106 @@ import pytest from traveltime_google_comparison.config import ( - TRAVELTIME_APP_ID_VAR_NAME, - TRAVELTIME_API_KEY_VAR_NAME, - retrieve_traveltime_credentials, + Provider, + Providers, + parse_json_to_providers, ) from traveltime_google_comparison.requests.traveltime_credentials import ( - TravelTimeCredentials, + Credentials, ) -def test_retrieve_traveltime_credentials_valid(monkeypatch): - monkeypatch.setenv(TRAVELTIME_APP_ID_VAR_NAME, "sample_app_id") - monkeypatch.setenv(TRAVELTIME_API_KEY_VAR_NAME, "sample_api_key") +def test_json_config_parse(): + json = """ + { + "traveltime": { + "app-id": "", + "api-key": "", + "max-rpm": "60" + }, + "api-providers": [ + { + "name": "google", + "enabled": true, + "api-key": "", + "max-rpm": "60" + }, + { + "name": "tomtom", + "enabled": false, + "api-key": "", + "max-rpm": "30" + } + ] + } + """ - credentials = retrieve_traveltime_credentials() + providers = parse_json_to_providers(json) - assert isinstance(credentials, TravelTimeCredentials) - assert credentials.app_id == "sample_app_id" - assert credentials.api_key == "sample_api_key" + assert providers == Providers( + base=Provider( + name="traveltime", + max_rpm=60, + credentials=Credentials(app_id="", api_key=""), + ), + competitors=[ + Provider( + name="google", max_rpm=60, credentials=Credentials("") + ) + ], + ) -def test_retrieve_traveltime_credentials_missing_app_id(monkeypatch): - monkeypatch.delenv(TRAVELTIME_APP_ID_VAR_NAME, raising=False) - monkeypatch.setenv(TRAVELTIME_API_KEY_VAR_NAME, "sample_api_key") +def test_json_config_parse_all_disabled_providers(): + json = """ + { + "traveltime": { + "app-id": "", + "api-key": "", + "max-rpm": "60" + }, + "api-providers": [ + { + "name": "google", + "enabled": false, + "api-key": "", + "max-rpm": "60" + }, + { + "name": "tomtom", + "enabled": false, + "api-key": "", + "max-rpm": "30" + } + ] + } + """ - with pytest.raises( - ValueError, - match="TravelTime API credentials are missing from environment variables.", - ): - retrieve_traveltime_credentials() + with pytest.raises(ValueError) as excinfo: + _ = parse_json_to_providers(json) + assert ( + str(excinfo.value) + == "There should be at least one enabled API provider that's not TravelTime." + ) -def test_retrieve_traveltime_credentials_missing_api_key(monkeypatch): - monkeypatch.setenv(TRAVELTIME_APP_ID_VAR_NAME, "sample_app_id") - monkeypatch.delenv(TRAVELTIME_API_KEY_VAR_NAME, raising=False) - with pytest.raises( - ValueError, - match="TravelTime API credentials are missing from environment variables.", - ): - retrieve_traveltime_credentials() +def test_json_config_parse_empty_providers(): + json = """ + { + "traveltime": { + "app-id": "", + "api-key": "", + "max-rpm": "60" + }, + "api-providers": [] + } + """ + with pytest.raises(ValueError) as excinfo: + _ = parse_json_to_providers(json) -def test_retrieve_traveltime_credentials_missing_both(monkeypatch): - monkeypatch.delenv(TRAVELTIME_APP_ID_VAR_NAME, raising=False) - monkeypatch.delenv(TRAVELTIME_API_KEY_VAR_NAME, raising=False) - - with pytest.raises( - ValueError, - match="TravelTime API credentials are missing from environment variables.", - ): - retrieve_traveltime_credentials() + assert ( + str(excinfo.value) + == "There should be at least one enabled API provider that's not TravelTime." + )