Skip to content

Commit

Permalink
Marking job complete if mark_complete is true
Browse files Browse the repository at this point in the history
  • Loading branch information
xop5 committed Jan 2, 2025
1 parent fe2326b commit d7e0ae6
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 11 deletions.
3 changes: 0 additions & 3 deletions cfa_azure/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,7 +1045,6 @@ def add_job(
pool_name: str | None = None,
save_logs_to_blob: str | None = None,
logs_folder: str | None = None,
end_job_on_task_failure: bool = False,
task_retries: int = 0,
mark_complete_after_tasks_run: bool = False,
) -> None:
Expand All @@ -1056,7 +1055,6 @@ def add_job(
pool_name (str|None): pool to use for job. If None, will use self.pool_name from client. Default None.
save_logs_to_blob (str): the name of the blob container. Must be mounted to the pool. Default None for no saving.
logs_folder (str): the folder structure to use when saving logs to blob. Default None will save to /stdout_stderr/ folder in specified blob container.
end_job_on_task_failure (bool): whether to end the job if a task fails. Default False.
task_retries (int): number of times to retry a task that fails. Default 0.
mark_complete_after_tasks_run (bool): whether to mark the job as completed when all tasks finish running. Default False.
"""
Expand Down Expand Up @@ -1089,7 +1087,6 @@ def add_job(
helpers.add_job(
job_id=job_id_r,
pool_id=p_name,
end_job_on_task_failure=end_job_on_task_failure,
batch_client=self.batch_client,
task_retries=task_retries,
mark_complete=mark_complete_after_tasks_run,
Expand Down
17 changes: 13 additions & 4 deletions cfa_azure/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
JobConstraints,
OnAllTasksComplete,
OnTaskFailure,
MetadataItem
)
from azure.containerregistry import ContainerRegistryClient
from azure.core.exceptions import HttpResponseError
Expand Down Expand Up @@ -728,7 +729,6 @@ def list_nodes_by_pool(pool_name: str, config: dict, node_state: str = None):
def add_job(
job_id: str,
pool_id: str,
end_job_on_task_failure: bool,
batch_client: object,
task_retries: int = 0,
mark_complete: bool = False,
Expand All @@ -738,7 +738,6 @@ def add_job(
Args:
job_id (str): name of the job to run
pool_id (str): name of pool
end_job_on_task_failure (bool): whether to end a running job if a task fails
batch_client (object): batch client object
task_retries (int): number of times to retry the task if it fails. Default 0.
mark_complete (bool): whether to mark the job complete after tasks finish running. Default False.
Expand All @@ -758,6 +757,7 @@ def add_job(
on_all_tasks_complete=on_all_tasks_complete,
on_task_failure=OnTaskFailure.perform_exit_options_job_action,
constraints=job_constraints,
metadata=[MetadataItem(name="mark_complete", value=mark_complete)]
)
logger.debug("Attempting to add job.")
try:
Expand Down Expand Up @@ -833,8 +833,17 @@ def add_task_to_job(
logger.debug("Adding task dependency.")
task_deps = batchmodels.TaskDependencies(task_ids=depends_on)

job_action = JobAction.none
if check_job_exists(job_id, batch_client):
job_details = batch_client.job.get(job_id)
if job_details and job_details.metadata:
for metadata in job_details.metadata:
if metadata.name == "mark_complete" and metadata.value == True:
job_action = JobAction.terminate
break

no_exit_options = ExitOptions(
dependency_action=DependencyAction.satisfy, job_action=JobAction.none
dependency_action=DependencyAction.satisfy, job_action=job_action
)
if run_dependent_tasks_on_fail:
exit_conditions = ExitConditions(
Expand All @@ -849,7 +858,7 @@ def add_task_to_job(
else:
terminate_exit_options = ExitOptions(
dependency_action=DependencyAction.block,
job_action=JobAction.terminate,
job_action=job_action,
)
exit_conditions = ExitConditions(
exit_codes=[
Expand Down
6 changes: 6 additions & 0 deletions tests/fake_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,12 @@ def delete(self, *args):

def add(self, job):
return True

def get(self, job):
return True

def list(self):
return []

class FakeTag:
def __init__(self, tag):
Expand Down
6 changes: 2 additions & 4 deletions tests/helpers_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,8 +726,7 @@ def test_add_job(self, mock_logger):
cfa_azure.helpers.add_job(
job_id,
FAKE_BATCH_POOL,
batch_client=batch_client,
end_job_on_task_failure=False,
batch_client=batch_client
)
mock_logger.info.assert_called_with(
f"Job '{job_id}' created successfully."
Expand All @@ -740,8 +739,7 @@ def test_add_job_task_failure(self, mock_logger):
cfa_azure.helpers.add_job(
job_id,
FAKE_BATCH_POOL,
batch_client=batch_client,
end_job_on_task_failure=False,
batch_client=batch_client
)
mock_logger.debug.assert_called_with("Attempting to add job.")

Expand Down

0 comments on commit d7e0ae6

Please sign in to comment.