diff --git a/kraken/ketos/repo.py b/kraken/ketos/repo.py index ea2f4dd9..00e48eed 100644 --- a/kraken/ketos/repo.py +++ b/kraken/ketos/repo.py @@ -24,20 +24,55 @@ import click from pathlib import Path +from difflib import get_close_matches + from .util import message logging.captureWarnings(True) logger = logging.getLogger('kraken') -def _get_field_list(name): +def _validate_script(script: str) -> str: + from htrmopo.util import _iso15924 + if script not in _iso15924: + return get_close_matches(script, _iso15924.keys()) + return script + + +def _validate_language(language: str) -> str: + from htrmopo.util import _iso639_3 + if language not in _iso639_3: + return get_close_matches(language, _iso639_3.keys()) + return language + + +def _validate_license(license: str) -> str: + from htrmopo.util import _licenses + if license not in _licenses: + return get_close_matches(license, _licenses.keys()) + return license + + +def _get_field_list(name, + validation_fn=lambda x: x, + required: bool = False): values = [] while True: - value = click.prompt(name, default=None) - if value is not None: - values.append(value) + value = click.prompt(name, default='') + if value: + if (cand := validation_fn(value)) == value: + values.append(value) + else: + message(f'Not a valid {name} value. Did you mean {cand}?') else: - break + if click.confirm(f'All `{name}` values added?'): + if required and not values: + message(f'`{name}` is a required field.') + continue + else: + break + else: + continue return values @@ -46,7 +81,7 @@ def _get_field_list(name): @click.option('-i', '--metadata', show_default=True, type=click.File(mode='r', lazy=True), help='Model card file for the model.') @click.option('-a', '--access-token', prompt=True, help='Zenodo access token') -@click.option('-d', '--doi', prompt=True, help='DOI of an existing record to update') +@click.option('-d', '--doi', help='DOI of an existing record to update') @click.option('-p', '--private/--public', default=False, help='Disables Zenodo ' 'community inclusion request. Allows upload of models that will not show ' 'up on `kraken list` output') @@ -56,15 +91,16 @@ def publish(ctx, metadata, access_token, doi, private, model): Publishes a model on the zenodo model repository. """ import json + import yaml import tempfile from htrmopo import publish_model, update_model - pub_fn = publish_model - from kraken.lib.vgsl import TorchVGSLModel from kraken.lib.progress import KrakenDownloadProgressBar + pub_fn = publish_model + _yaml_delim = r'(?:---|\+\+\+)' _yaml = r'(.*?)' _content = r'\s*(.+)$' @@ -77,27 +113,44 @@ def publish(ctx, metadata, access_token, doi, private, model): # construct metadata if none is given if metadata: frontmatter, content = _yaml_regex.match(metadata.read()).groups() + frontmatter = yaml.safe_load(frontmatter) else: frontmatter['summary'] = click.prompt('summary') content = click.edit('Write long form description (training data, transcription standards) of the model in markdown format here') creators = [] + message('To stop adding authors, leave the author name field empty.') while True: - author = click.prompt('author', default=None) - affiliation = click.prompt('affiliation', default=None) - orcid = click.prompt('orcid', default=None) - if author is not None: - creators.append({'author': author}) + author = click.prompt('author name', default='') + if author: + creators.append({'name': author}) else: - break + if click.confirm('All authors added?'): + break + else: + continue + affiliation = click.prompt('affiliation', default='') + orcid = click.prompt('orcid', default='') if affiliation is not None: creators[-1]['affiliation'] = affiliation if orcid is not None: creators[-1]['orcid'] = orcid + if not creators: + raise click.UsageError('The `authors` field is obligatory. Aborting') + frontmatter['authors'] = creators - frontmatter['license'] = click.prompt('license') - frontmatter['language'] = _get_field_list('language') - frontmatter['script'] = _get_field_list('script') + while True: + license = click.prompt('license') + if (lic := _validate_license(license)) == license: + frontmatter['license'] = license + break + else: + message(f'Not a valid license identifer. Did you mean {lic}?') + + message('To stop adding values to the following fields, enter an empty field.') + + frontmatter['language'] = _get_field_list('language', _validate_language, required=True) + frontmatter['script'] = _get_field_list('script', _validate_script, required=True) if len(tags := _get_field_list('tag')): frontmatter['tags'] = tags + ['kraken_pytorch'] @@ -108,30 +161,33 @@ def publish(ctx, metadata, access_token, doi, private, model): # take last metrics field, falling back to accuracy field in model metadata metrics = {} - if 'metrics' in nn.user_metadata and nn.user_metadata['metrics']: - metrics['cer'] = 100 - nn.user_metadata['metrics'][-1][1]['val_accuracy'] - metrics['wer'] = 100 - nn.user_metadata['metrics'][-1][1]['val_word_accuracy'] - elif 'accuracy' in nn.user_metadata and nn.user_metadata['accuracy']: - metrics['cer'] = 100 - nn.user_metadata['accuracy'] + if nn.user_metadata.get('metrics', None) is not None: + if (val_accuracy := nn.user_metadata['metrics'][-1][1].get('val_accuracy', None)) is not None: + metrics['cer'] = 100 - (val_accuracy * 100) + if (val_word_accuracy := nn.user_metadata['metrics'][-1][1].get('val_word_accuracy', None)) is not None: + metrics['wer'] = 100 - (val_word_accuracy * 100) + elif (accuracy := nn.user_metadata.get('accuracy', None)) is not None: + metrics['cer'] = 100 - accuracy frontmatter['metrics'] = metrics software_hints = ['kind=vgsl'] # some recognition-specific software hints if nn.model_type == 'recognition': - software_hints.append([f'seg_type={nn.seg_type}', f'one_channel_mode={nn.one_channel_mode}', 'legacy_polygons={nn.user_metadata["legacy_polygons"]}']) + software_hints.extend([f'seg_type={nn.seg_type}', f'one_channel_mode={nn.one_channel_mode}', f'legacy_polygons={nn.user_metadata["legacy_polygons"]}']) frontmatter['software_hints'] = software_hints frontmatter['software_name'] = 'kraken' + frontmatter['model_type'] = [nn.model_type] # build temporary directory with tempfile.TemporaryDirectory() as tmpdir, KrakenDownloadProgressBar() as progress: upload_task = progress.add_task('Uploading', total=0, visible=True if not ctx.meta['verbose'] else False) - model = Path(model) + model = Path(model).resolve() tmpdir = Path(tmpdir) - (tmpdir / model.name).symlink_to(model) - # v0 metadata only supports recognition models + (tmpdir / model.name).resolve().symlink_to(model) if nn.model_type == 'recognition': + # v0 metadata only supports recognition models v0_metadata = { 'summary': frontmatter['summary'], 'description': content, @@ -145,7 +201,7 @@ def publish(ctx, metadata, access_token, doi, private, model): with open(tmpdir / 'metadata.json', 'w') as fo: json.dump(v0_metadata, fo) kwargs = {'model': tmpdir, - 'model_card': f'---\n{frontmatter}---\n{content}', + 'model_card': f'---\n{yaml.dump(frontmatter)}---\n{content}', 'access_token': access_token, 'callback': lambda total, advance: progress.update(upload_task, total=total, advance=advance), 'private': private} diff --git a/kraken/repo.py b/kraken/repo.py index f283deb7..5168da3c 100644 --- a/kraken/repo.py +++ b/kraken/repo.py @@ -1,5 +1,5 @@ # -# Copyright 2015 Benjamin Kiessling +# Copyright 2025 Benjamin Kiessling # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,15 +17,12 @@ ~~~~~~~~~~~ Wrappers around the htrmopo reference implementation implementing -kraken-specific filtering. +kraken-specific filtering for repository querying operations. """ -import logging -import warnings -from pathlib import Path from collections import defaultdict -from typing import IO, Any, Dict, List, Union, cast, Optional, TypeVar, Iterable, Literal - from collections.abc import Callable +from typing import Any, Dict, Optional, TypeVar, Literal + from htrmopo import get_description as mopo_get_description from htrmopo import get_listing as mopo_get_listing