Skip to content

Commit

Permalink
Rename "try_number" increments that are unrelated to the Airflow concept (#39317)
Browse files Browse the repository at this point in the history

This avoids some confusion and false positives when looking for try_number mutation in the database.
  • Loading branch information
dstandish authored Apr 29, 2024
1 parent 6112745 commit 4fa9fe1
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 13 deletions.
14 changes: 8 additions & 6 deletions airflow/providers/amazon/aws/hooks/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,21 +494,23 @@ def poll_query_status(
:param poll_interval: Time (in seconds) to wait between calls to check query status on EMR
:param max_polling_attempts: Number of times to poll for query state before function exits
"""
try_number = 1
poll_attempt = 1
while True:
query_state = self.check_query_status(job_id)
if query_state in self.TERMINAL_STATES:
self.log.info("Try %s: Query execution completed. Final state is %s", try_number, query_state)
self.log.info(
"Try %s: Query execution completed. Final state is %s", poll_attempt, query_state
)
return query_state
if query_state is None:
self.log.info("Try %s: Invalid query state. Retrying again", try_number)
self.log.info("Try %s: Invalid query state. Retrying again", poll_attempt)
else:
self.log.info("Try %s: Query is still in non-terminal state - %s", try_number, query_state)
self.log.info("Try %s: Query is still in non-terminal state - %s", poll_attempt, query_state)
if (
max_polling_attempts and try_number >= max_polling_attempts
max_polling_attempts and poll_attempt >= max_polling_attempts
): # Break loop if max_polling_attempts reached
return query_state
try_number += 1
poll_attempt += 1
time.sleep(poll_interval)

def stop_query(self, job_id: str) -> dict:
Expand Down
14 changes: 7 additions & 7 deletions airflow/sensors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def run_duration() -> float:
def run_duration() -> float:
return time.monotonic() - start_monotonic

try_number = 1
poke_count = 1
log_dag_id = self.dag.dag_id if self.has_dag() else ""

xcom_value = None
Expand Down Expand Up @@ -312,7 +312,7 @@ def run_duration() -> float:
else:
raise AirflowSensorTimeout(message)
if self.reschedule:
next_poke_interval = self._get_next_poke_interval(started_at, run_duration, try_number)
next_poke_interval = self._get_next_poke_interval(started_at, run_duration, poke_count)
reschedule_date = timezone.utcnow() + timedelta(seconds=next_poke_interval)
if _is_metadatabase_mysql() and reschedule_date > _MYSQL_TIMESTAMP_MAX:
raise AirflowSensorTimeout(
Expand All @@ -321,8 +321,8 @@ def run_duration() -> float:
)
raise AirflowRescheduleException(reschedule_date)
else:
time.sleep(self._get_next_poke_interval(started_at, run_duration, try_number))
try_number += 1
time.sleep(self._get_next_poke_interval(started_at, run_duration, poke_count))
poke_count += 1
self.log.info("Success criteria met. Exiting.")
return xcom_value

Expand All @@ -338,17 +338,17 @@ def _get_next_poke_interval(
self,
started_at: datetime.datetime | float,
run_duration: Callable[[], float],
try_number: int,
poke_count: int,
) -> float:
"""Use similar logic which is used for exponential backoff retry delay for operators."""
if not self.exponential_backoff:
return self.poke_interval

# The value of min_backoff should always be greater than or equal to 1.
min_backoff = max(int(self.poke_interval * (2 ** (try_number - 2))), 1)
min_backoff = max(int(self.poke_interval * (2 ** (poke_count - 2))), 1)

run_hash = int(
hashlib.sha1(f"{self.dag_id}#{self.task_id}#{started_at}#{try_number}".encode()).hexdigest(),
hashlib.sha1(f"{self.dag_id}#{self.task_id}#{started_at}#{poke_count}".encode()).hexdigest(),
16,
)
modded_hash = min_backoff + run_hash % min_backoff
Expand Down

0 comments on commit 4fa9fe1

Please sign in to comment.