Skip to content

Commit

Permalink
Removed redundant arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
xop5 committed Jan 16, 2025
1 parent 0522e08 commit b9333ec
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 122 deletions.
6 changes: 1 addition & 5 deletions cfa_azure/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -1242,11 +1242,7 @@ def monitor_job(self, job_id: str, timeout: str | None = None) -> None:
monitor = helpers.monitor_tasks(
job_id,
timeout,
self.batch_client,
self.resource_group_name,
self.account_name,
self.pool_name,
self.batch_mgmt_client,
self.batch_client
)
print(monitor)

Expand Down
142 changes: 25 additions & 117 deletions cfa_azure/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
logger = logging.getLogger(__name__)


def read_config(config_path: str = "./configuration.toml") -> dict:
def read_config(config_path: str = "./configuration.toml"):
"""takes in a path to a configuration toml file and returns it as a json object
Args:
Expand Down Expand Up @@ -235,7 +235,7 @@ def create_blob_containers(
"""creates the input and output Blob containers based on given names
Args:
blob_service_client (object): an instance of the Blob Service Client
blob_service_client (object): an instance of the Blob Service Client
input_container_name (str): user specified name for input container. Default is None.
output_container_name (str): user specified name for output container. Default is None.
"""
Expand Down Expand Up @@ -277,11 +277,9 @@ def get_batch_pool_json(
output_container_name (str): user specified name for input container
config (dict): config dictionary
autoscale_formula_path (str): path to the autoscale formula
autoscale_evaluation_interval (str): time period for autoscale evaluation. Default is 15 minutes.
fixedscale_resize_timeout (str): timeout for resizing fixed scale pools. Default is 15 minutes.
Returns:
dict: relevant information for Batch pool creation
dict: relevant information for Batch pool creation
"""
logger.debug("Preparing batch pool configuration...")
# User-assigned identity for the pool
Expand Down Expand Up @@ -430,17 +428,6 @@ def update_pool(
account_name: str,
resource_group_name: str,
) -> dict:
"""
Args:
pool_name (str): name of pool to update
pool_parameters (dict): pool parameters dictionary
batch_mgmt_client (object): instance of BatchManagementClient object
account_name (str): name of Azure Account
resource_group_name (str): name of Resource Group in Azure
Returns:
dict: json of pool_id and updation_time
"""
print("Updating the pool...")

start_time = datetime.datetime.now()
Expand Down Expand Up @@ -585,13 +572,6 @@ def upload_blob_file(


def walk_folder(folder: str) -> list | None:
"""
Args:
folder (str): folder path
Returns:
list: list of file names contained in folder
"""
file_list = []
for dirname, _, fname in walk(folder):
for f in fname:
Expand Down Expand Up @@ -758,7 +738,7 @@ def add_job(
Args:
job_id (str): name of the job to run
pool_id (str): name of pool
batch_client (object): BatchClient object
batch_client (object): batch client object
task_retries (int): number of times to retry the task if it fails. Default 3.
mark_complete (bool): whether to mark the job complete after tasks finish running. Default False.
"""
Expand Down Expand Up @@ -970,7 +950,11 @@ def add_task_to_job(
return task_id


def monitor_tasks(job_id: str, timeout: int, batch_client: object):
def monitor_tasks(
job_id: str,
timeout: int,
batch_client: object
):
"""monitors tasks running in the job based on job ID
Args:
Expand Down Expand Up @@ -1330,7 +1314,7 @@ def get_mount_config(blob_config: list[str]):
"""takes blob configurations as input and combines them to create a mount configuration.
Args:
blob_config (list): usually from get_blob_config(). Usually one for input blob and one for output blob.
Blob configurations, usually from get_blob_config(). Usually one for input blob and one for output blob.
Returns:
list: mount configuration to used with get_pool_parameters.
Expand Down Expand Up @@ -1378,7 +1362,6 @@ def get_pool_parameters(
use_default_autoscale_formula (bool, optional)
max_autoscale_nodes (int): maximum number of nodes to use with autoscaling. Default 3.
task_slots_per_node (int): number of task slots per node. Default is 1.
availability_zones (bool): whether to use availability zones. Default False.
use_hpc_image (bool): whether to use a high performance compute image for each node. Default False.
Returns:
Expand Down Expand Up @@ -1430,13 +1413,12 @@ def get_pool_parameters(
"taskSlotsPerNode": task_slots_per_node,
"taskSchedulingPolicy": {"nodeFillType": "Spread"},
"deploymentConfiguration": get_deployment_config(
container_image_name=container_image_name,
container_registry_url=container_registry_url,
container_registry_server=container_registry_server,
config=config,
credential=credential,
availability_zones=availability_zones,
use_hpc_image=use_hpc_image,
container_image_name,
container_registry_url,
container_registry_server,
config,
credential,
use_hpc_image,
),
"networkConfiguration": get_network_config(config),
"scaleSettings": scale_settings,
Expand Down Expand Up @@ -1501,13 +1483,13 @@ def download_file(
src_path: str,
dest_path: str,
do_check: bool = True,
verbose: bool = False,
verbose=False,
) -> None:
"""
Download a file from Azure Blob storage
Args:
c_client (ContainerClient):
client (ContainerClient):
Instance of ContainerClient provided with the storage account
src_path (str):
Path within the container to the desired file (including filename)
Expand All @@ -1517,8 +1499,6 @@ def download_file(
Name of the storage container containing the file to be downloaded
do_check (bool):
Whether or not to do an existence check
verbose (bool):
Whether to be verbose in printing information
Raises:
ValueError:
Expand Down Expand Up @@ -1765,6 +1745,7 @@ def upload_docker_image(
registry_name (str): name of Azure Container Registry
repo_name (str): name of repo
tag (str): tag for the Docker container. Default is "latest". If None, a timestamp tag will be generated.
path_to_dockerfile (str): path to Dockerfile. Default is ./Dockerfile.
use_device_code (bool): whether to use the device code when authenticating. Default False.
Returns:
Expand Down Expand Up @@ -1821,10 +1802,10 @@ def check_pool_exists(
"""Check if a pool exists in Azure Batch
Args:
resource_group_name (str): Azure resource group name
account_name (str): Azure account name
pool_name (str): name of pool
batch_mgmt_client (object): instance of BatchManagementClient
resource_group_name (str):
account_name (str):
pool_name (str):
batch_mgmt_client (object):
Returns:
bool: whether the pool exists
Expand Down Expand Up @@ -1853,7 +1834,7 @@ def get_pool_info(
resource_group_name (str): name of resource group
account_name (str): name of account
pool_name (str): name of pool
batch_mgmt_client (object): instance of BatchManagementClient
batch_mgmt_client (object): instance of Batch Management Client
Returns:
dict: json with name, last_modified, creation_time, vm_size, and task_slots_per_node info
Expand Down Expand Up @@ -1885,7 +1866,7 @@ def get_pool_full_info(
resource_group_name (str): name of resource group
account_name (str): name of account
pool_name (str): name of pool
batch_mgmt_client (object): instance of BatchManagementClient
batch_mgmt_client (object): instance of Batch Management Client
Returns:
dict: dictionary with full pool information
Expand Down Expand Up @@ -1983,12 +1964,6 @@ def check_config_req(config: str):
def get_container_registry_client(
    endpoint: str, credential: object, audience: str
):
    """Creates a ContainerRegistryClient for the given registry endpoint.

    Thin factory wrapper; all arguments are forwarded to the client
    constructor, with ``audience`` passed as a keyword argument.

    Args:
        endpoint (str): the endpoint to the container registry
        credential (object): a credential object used to authenticate
        audience (str): audience for the container registry client

    Returns:
        ContainerRegistryClient: client for interacting with the registry
    """
    return ContainerRegistryClient(endpoint, credential, audience=audience)


Expand Down Expand Up @@ -2072,29 +2047,13 @@ def generate_autoscale_formula(max_nodes: int = 8) -> str:


def format_rel_path(rel_path: str) -> str:
    """
    Formats a relative path into the right format for Azure services
    by stripping a single leading slash if present.

    Args:
        rel_path (str): relative mount path

    Returns:
        str: formatted relative path (no leading "/")
    """
    # str.removeprefix drops at most one leading "/", matching the
    # original startswith()/slice behavior (requires Python 3.9+;
    # the file already uses 3.10+ `str | None` syntax).
    rel_path = rel_path.removeprefix("/")
    logger.debug(f"path formatted to {rel_path}")
    return rel_path


def get_timeout(_time: str) -> int:
"""
Args:
_time (str): formatted timeout string
Returns:
int: integer of timeout in minutes
"""
t = _time.split("PT")[-1]
if "H" in t:
if "M" in t:
Expand All @@ -2112,15 +2071,6 @@ def get_timeout(_time: str) -> int:
def list_blobs_flat(
container_name: str, blob_service_client: BlobServiceClient, verbose=True
):
"""
Args:
container_name (str): name of Blob container
blob_service_client (object): instance of BlobServiceClient
verbose (bool): whether to be verbose in printing files. Default True.
Returns:
list: list of blobs in Blob container
"""
logger.debug("Creating container client for getting Blob info.")
container_client = blob_service_client.get_container_client(
container=container_name
Expand Down Expand Up @@ -2176,12 +2126,6 @@ def get_log_level() -> int:
def delete_blob_snapshots(
blob_name: str, container_name: str, blob_service_client: object
):
"""
Args:
blob_name (str): name of blob
container_name (str): name of container
blob_service_client (object): instance of BlobServiceClient
"""
blob_client = blob_service_client.get_blob_client(
container=container_name, blob=blob_name
)
Expand All @@ -2192,12 +2136,6 @@ def delete_blob_snapshots(
def delete_blob_folder(
folder_path: str, container_name: str, blob_service_client: object
):
"""
Args:
folder_path (str): path to blob folder
container_name (str): name of Blob container
blob_service_client (object): instance of BlobServiceClient
"""
# create container client
c_client = blob_service_client.get_container_client(
container=container_name
Expand All @@ -2215,14 +2153,6 @@ def delete_blob_folder(


def format_extensions(extension):
"""
Formats extensions to include periods.
Args:
extension (str | list): string or list of strings of extensions. Can include a leading period but does not need to.
Returns:
list: list of formatted extensions
"""
if isinstance(extension, str):
extension = [extension]
ext = []
Expand All @@ -2245,7 +2175,6 @@ def check_autoscale_parameters(
"""Checks which arguments are incompatible with the provided scale mode
Args:
mode (str): pool mode, chosen from 'fixed' or 'autoscale'
dedicated_nodes (int): optional, the target number of dedicated compute nodes for the pool in fixed scaling mode. Defaults to None.
low_priority_nodes (int): optional, the target number of spot compute nodes for the pool in fixed scaling mode. Defaults to None.
node_deallocation_option (str): optional, determines what to do with a node and its running tasks after it has been selected for deallocation. Defaults to None.
Expand Down Expand Up @@ -2285,17 +2214,6 @@ def get_rel_mnt_path(
account_name: str,
batch_mgmt_client: object,
):
"""
Args:
blob_name (str): name of blob container
pool_name (str): name of pool
resource_group_name (str): name of resource group in Azure
account_name (str): name of account in Azure
batch_mgmt_object (object): instance of BatchManagementClient
Returns:
str: relative mount path for the blob and pool specified
"""
try:
pool_info = get_pool_full_info(
resource_group_name=resource_group_name,
Expand Down Expand Up @@ -2327,16 +2245,6 @@ def get_pool_mounts(
account_name: str,
batch_mgmt_client: object,
):
"""
Args:
pool_name (str): name of pool
resource_group_name (str): name of resource group in Azure
account_name (str): name of account in Azure
batch_mgmt_client (object): instance of BatchManagementClient
Returns:
list: list of mounts in specified pool
"""
try:
pool_info = get_pool_full_info(
resource_group_name=resource_group_name,
Expand Down

0 comments on commit b9333ec

Please sign in to comment.