Skip to content

Commit

Permalink
Rewrite the conditional tagging logic using hcubes_intdiff2 and apply…
Browse files Browse the repository at this point in the history
… it on the constraints-intersected request(s)
  • Loading branch information
ecmwf-cobarzan committed Dec 4, 2024
1 parent a86b339 commit 8473f0f
Showing 1 changed file with 27 additions and 63 deletions.
90 changes: 27 additions & 63 deletions cads_adaptors/adaptors/cds.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import os
import pathlib
import datetime
import re
from copy import deepcopy
from random import randint
from typing import Any, BinaryIO
Expand All @@ -10,9 +8,11 @@
from cads_adaptors.adaptors import AbstractAdaptor, Context, Request
from cads_adaptors.exceptions import InvalidRequest
from cads_adaptors.tools.general import ensure_list
from cads_adaptors.tools.hcube_tools import hcubes_intdiff2
from cads_adaptors.validation import enforce



class AbstractCdsAdaptor(AbstractAdaptor):
resources = {"CADS_ADAPTORS": 1}
adaptor_schema: dict[str, Any] = {}
Expand Down Expand Up @@ -142,54 +142,17 @@ def pre_mapping_modifications(self, request: dict[str, Any]) -> dict[str, Any]:

return request

def ensure_dateranges(self, strings):
dateranges = []
for string in strings:
dates = re.split("[;/]", string)
if len(dates) == 1:
dates *= 2
dateranges.append(dates)
return dateranges

def instantiate_dynamic_daterange(self, string: str, today: datetime):
dates = re.split("[;/]", string)
if len(dates) == 1:
dates *= 2
for i,date in enumerate(dates):
if date.startswith("current"):
diff = date.replace("current","")
diff = int(diff) if diff else 0
date = today + datetime.timedelta(diff)
dates[i] = date.strftime("%Y-%m-%d")
return f"{dates[0]}/{dates[1]}"

def preprocess_conditions(self, conditions):
today = datetime.datetime.now(datetime.timezone.utc)
for condition in conditions:
if "date" in condition:
condition["date"] = self.instantiate_dynamic_daterange(condition["date"], today)

def dateranges_in(self, contained, container):
container_interval = re.split("[;/]", container)
contained_intervals = self.ensure_dateranges(contained)
for contained_interval in contained_intervals:
if not (container_interval[0] <= contained_interval[0] and contained_interval[1] <= container_interval[1]):
return False
return True

def satisfy_condition(self, request: dict[str, Any], condition: dict[str, Any]):
for key in condition:
if key == "date":
if not self.dateranges_in(request[key], condition[key]):
return False
elif not set(ensure_list(request[key])) <= set(ensure_list(condition[key])):
return False
return True

def satisfy_conditions(self, request: dict[str, Any], conditions: list[dict[str, Any]]):
for condition in conditions:
if self.satisfy_condition(request, condition):
return True
def ensure_list_values(self, dicts):
for d in dicts:
for key in d:
d[key] = ensure_list(d[key])

def satisfy_conditions(self, requests: list[dict[str, list[Any]]], conditions: list[dict[str, list[Any]]]):
try:
_, d12, _ = hcubes_intdiff2(requests, conditions)
return not d12
except Exception as e:
return False

def normalise_request(self, request: Request) -> Request:
"""
Expand Down Expand Up @@ -219,19 +182,6 @@ def normalise_request(self, request: Request) -> Request:
# Pre-mapping modifications
working_request = self.pre_mapping_modifications(deepcopy(request))

# Implement a request-level tagging system
try:
self.conditional_tagging = self.config.get("conditional_tagging", None)
if self.conditional_tagging is not None:
for tag in self.conditional_tagging:
conditions = self.conditional_tagging[tag]
self.preprocess_conditions(conditions)
if self.satisfy_conditions(request, conditions):
hidden_tag = f"__{tag}"
request[hidden_tag] = True
except Exception as e:
self.context.add_stdout(f"An error occured while attempting conditional tagging: {e!r}")

# If specified by the adaptor, intersect the request with the constraints.
# The intersected_request is a list of requests
if self.intersect_constraints_bool:
Expand All @@ -243,6 +193,20 @@ def normalise_request(self, request: Request) -> Request:
else:
self.intersected_requests = ensure_list(working_request)

# Implement a request-level tagging system
try:
self.conditional_tagging = self.config.get("conditional_tagging", None)
if self.conditional_tagging is not None:
self.ensure_list_values(self.intersected_requests)
for tag in self.conditional_tagging:
conditions = self.conditional_tagging[tag]
self.ensure_list_values(conditions)
if self.satisfy_conditions(self.intersected_requests, conditions):
hidden_tag = f"__{tag}"
request[hidden_tag] = True
except Exception as e:
self.context.add_stdout(f"An error occured while attempting conditional tagging: {e!r}")

# Map the list of requests
self.mapped_requests = [
self.apply_mapping(i_request) for i_request in self.intersected_requests
Expand Down

0 comments on commit 8473f0f

Please sign in to comment.