Skip to content

Commit

Permalink
wip for new model repository
Browse files Browse the repository at this point in the history
  • Loading branch information
mittagessen committed Jan 3, 2025
1 parent 08db477 commit ffb1ef8
Show file tree
Hide file tree
Showing 7 changed files with 226 additions and 359 deletions.
3 changes: 0 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,6 @@ jobs:
python -m build --sdist --wheel --outdir dist/ .
- name: Publish a Python distribution to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}
- name: Upload PyPI artifacts to GH storage
uses: actions/upload-artifact@v3
with:
Expand Down
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,5 @@ dependencies:
- setuptools>=36.6.0,<70.0.0
- pip:
- coremltools~=8.1
- htrmopo
- file:.
1 change: 1 addition & 0 deletions environment_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,5 @@ dependencies:
- setuptools>=36.6.0,<70.0.0
- pip:
- coremltools~=8.1
- htrmopo
- file:.
169 changes: 107 additions & 62 deletions kraken/ketos/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,94 +18,139 @@
Command line driver for publishing models to the model repository.
"""
import re
import logging
import os

import click

from pathlib import Path
from .util import message

logging.captureWarnings(True)
logger = logging.getLogger('kraken')


def _get_field_list(name):
    """
    Interactively collects a list of values for a single metadata field.

    The user is prompted repeatedly with `name`; each non-empty answer is
    appended to the result. Submitting an empty answer ends the collection.

    Args:
        name: Prompt text shown to the user for every entry.

    Returns:
        A list with the entered strings in input order. The list is empty
        if the first answer is already empty.
    """
    values = []
    while True:
        # `click.prompt` re-asks until it gets input when `default` is None,
        # so an empty answer could never terminate the loop. An empty string
        # default lets a bare <Enter> come back as '' and end the collection;
        # `show_default=False` keeps the prompt free of a `[]` suffix.
        value = click.prompt(name, default='', show_default=False)
        if not value:
            break
        values.append(value)
    return values


@click.command('publish')
@click.pass_context
@click.option('-i', '--metadata', show_default=True,
type=click.File(mode='r', lazy=True), help='Metadata for the '
'model. Will be prompted from the user if not given')
type=click.File(mode='r', lazy=True), help='Model card file for the model.')
@click.option('-a', '--access-token', prompt=True, help='Zenodo access token')
@click.option('-d', '--doi', prompt=True, help='DOI of an existing record to update')
@click.option('-p', '--private/--public', default=False, help='Disables Zenodo '
'community inclusion request. Allows upload of models that will not show '
'up on `kraken list` output')
@click.argument('model', nargs=1, type=click.Path(exists=False, readable=True, dir_okay=False))
def publish(ctx, metadata, access_token, private, model):
def publish(ctx, metadata, access_token, doi, private, model):
"""
Publishes a model on the zenodo model repository.
"""
import json
import tempfile

from htrmopo import publish_model, update_model

from importlib import resources
from jsonschema import validate
from jsonschema.exceptions import ValidationError
pub_fn = publish_model

from kraken import repo
from kraken.lib import models
from kraken.lib.vgsl import TorchVGSLModel
from kraken.lib.progress import KrakenDownloadProgressBar

ref = resources.files('kraken').joinpath('metadata.schema.json')
with open(ref, 'rb') as fp:
schema = json.load(fp)

nn = models.load_any(model)

if not metadata:
author = click.prompt('author')
affiliation = click.prompt('affiliation')
summary = click.prompt('summary')
description = click.edit('Write long form description (training data, transcription standards) of the model here')
accuracy_default = None
# take last accuracy measurement in model metadata
if 'accuracy' in nn.nn.user_metadata and nn.nn.user_metadata['accuracy']:
accuracy_default = nn.nn.user_metadata['accuracy'][-1][1] * 100
accuracy = click.prompt('accuracy on test set', type=float, default=accuracy_default)
script = [
click.prompt(
'script',
type=click.Choice(
sorted(
schema['properties']['script']['items']['enum'])),
show_choices=True)]
license = click.prompt(
'license',
type=click.Choice(
sorted(
schema['properties']['license']['enum'])),
show_choices=True)
metadata = {
'authors': [{'name': author, 'affiliation': affiliation}],
'summary': summary,
'description': description,
'accuracy': accuracy,
'license': license,
'script': script,
'name': os.path.basename(model),
'graphemes': ['a']
}
while True:
try:
validate(metadata, schema)
except ValidationError as e:
message(e.message)
metadata[e.path[-1]] = click.prompt(e.path[-1], type=float if e.schema['type'] == 'number' else str)
continue
break
_yaml_delim = r'(?:---|\+\+\+)'
_yaml = r'(.*?)'
_content = r'\s*(.+)$'
_re_pattern = r'^\s*' + _yaml_delim + _yaml + _yaml_delim + _content
_yaml_regex = re.compile(_re_pattern, re.S | re.M)

nn = TorchVGSLModel.load_model(model)

frontmatter = {}
# construct metadata if none is given
if metadata:
frontmatter, content = _yaml_regex.match(metadata.read()).groups()
else:
metadata = json.load(metadata)
validate(metadata, schema)
metadata['graphemes'] = [char for char in ''.join(nn.codec.c2l.keys())]
with KrakenDownloadProgressBar() as progress:
frontmatter['summary'] = click.prompt('summary')
content = click.edit('Write long form description (training data, transcription standards) of the model in markdown format here')

creators = []
while True:
author = click.prompt('author', default=None)
affiliation = click.prompt('affiliation', default=None)
orcid = click.prompt('orcid', default=None)
if author is not None:
creators.append({'author': author})
else:
break
if affiliation is not None:
creators[-1]['affiliation'] = affiliation
if orcid is not None:
creators[-1]['orcid'] = orcid
frontmatter['authors'] = creators
frontmatter['license'] = click.prompt('license')
frontmatter['language'] = _get_field_list('language')
frontmatter['script'] = _get_field_list('script')

if len(tags := _get_field_list('tag')):
frontmatter['tags'] = tags + ['kraken_pytorch']
if len(datasets := _get_field_list('dataset URL')):
frontmatter['datasets'] = datasets
if len(base_model := _get_field_list('base model URL')):
frontmatter['base_model'] = base_model

# take last metrics field, falling back to accuracy field in model metadata
metrics = {}
if 'metrics' in nn.user_metadata and nn.user_metadata['metrics']:
metrics['cer'] = 100 - nn.user_metadata['metrics'][-1][1]['val_accuracy']
metrics['wer'] = 100 - nn.user_metadata['metrics'][-1][1]['val_word_accuracy']
elif 'accuracy' in nn.user_metadata and nn.user_metadata['accuracy']:
metrics['cer'] = 100 - nn.user_metadata['accuracy']
frontmatter['metrics'] = metrics
software_hints = ['kind=vgsl']

# some recognition-specific software hints
if nn.model_type == 'recognition':
software_hints.append([f'seg_type={nn.seg_type}', f'one_channel_mode={nn.one_channel_mode}', 'legacy_polygons={nn.user_metadata["legacy_polygons"]}'])
frontmatter['software_hints'] = software_hints

frontmatter['software_name'] = 'kraken'

# build temporary directory
with tempfile.TemporaryDirectory() as tmpdir, KrakenDownloadProgressBar() as progress:
upload_task = progress.add_task('Uploading', total=0, visible=True if not ctx.meta['verbose'] else False)
oid = repo.publish_model(model, metadata, access_token, lambda total, advance: progress.update(upload_task, total=total, advance=advance), private)
message('model PID: {}'.format(oid))

model = Path(model)
tmpdir = Path(tmpdir)
(tmpdir / model.name).symlink_to(model)
# v0 metadata only supports recognition models
if nn.model_type == 'recognition':
v0_metadata = {
'summary': frontmatter['summary'],
'description': content,
'license': frontmatter['license'],
'script': frontmatter['script'],
'name': model.name,
'graphemes': [char for char in ''.join(nn.codec.c2l.keys())]
}
if frontmatter['metrics']:
v0_metadata['accuracy'] = 100 - metrics['cer']
with open(tmpdir / 'metadata.json', 'w') as fo:
json.dump(v0_metadata, fo)
kwargs = {'model': tmpdir,
'model_card': f'---\n{frontmatter}---\n{content}',
'access_token': access_token,
'callback': lambda total, advance: progress.update(upload_task, total=total, advance=advance),
'private': private}
if doi:
pub_fn = update_model
kwargs['model_id'] = doi
oid = pub_fn(**kwargs)
message(f'model PID: {oid}')
136 changes: 116 additions & 20 deletions kraken/kraken.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,15 @@
import click
from PIL import Image
from importlib import resources

from rich import print
from rich.tree import Tree
from rich.table import Table
from rich.console import Group
from rich.traceback import install
from rich.logging import RichHandler
from rich.markdown import Markdown
from rich.progress import Progress

from kraken.lib import log

Expand Down Expand Up @@ -677,29 +685,90 @@ def ocr(ctx, model, pad, reorder, base_dir, no_segmentation, text_direction):

@cli.command('show')
@click.pass_context
@click.option('-V', '--metadata-version',
default='highest',
type=click.Choice(['v0', 'v1', 'highest']),
help='Version of metadata to fetch if multiple exist in repository.')
@click.argument('model_id')
def show(ctx, model_id):
def show(ctx, metadata_version, model_id):
"""
Retrieves model metadata from the repository.
"""
from kraken import repo
from htrmopo import get_description
from htrmopo.util import iso15924_to_name, iso639_3_to_name
from kraken.lib.util import is_printable, make_printable

desc = repo.get_description(model_id)
def _render_creators(creators):
o = []
for creator in creators:
c_text = creator['name']
if (orcid := creator.get('orcid', None)) is not None:
c_text += f' ({orcid})'
if (affiliation := creator.get('affiliation', None)) is not None:
c_text += f' ({affiliation})'
o.append(c_text)
return o

chars = []
combining = []
for char in sorted(desc['graphemes']):
if not is_printable(char):
combining.append(make_printable(char))
else:
chars.append(char)
message(
'name: {}\n\n{}\n\n{}\nscripts: {}\nalphabet: {} {}\naccuracy: {:.2f}%\nlicense: {}\nauthor(s): {}\ndate: {}'.format(
model_id, desc['summary'], desc['description'], ' '.join(
desc['script']), ''.join(chars), ', '.join(combining), desc['accuracy'], desc['license']['id'], '; '.join(
x['name'] for x in desc['creators']), desc['publication_date']))
ctx.exit(0)
def _render_metrics(metrics):
return [f'{k}: {v:.2f}' for k, v in metrics.items()]

if metadata_version == 'highest':
metadata_version = None

try:
desc = get_description(model_id, version=metadata_version)
except ValueError as e:
logger.error(e)
ctx.exit(1)

if getattr(desc, 'software_name', None) != 'kraken' or 'kraken_pytorch' not in desc.keywords:
logger.error('Record exists but is not a kraken-compatible model')
ctx.exit(1)

if desc.version == 'v0':
chars = []
combining = []
for char in sorted(desc.graphemes):
if not is_printable(char):
combining.append(make_printable(char))
else:
chars.append(char)

table = Table(title=desc.summary, show_header=False)
table.add_column('key', justify="left", no_wrap=True)
table.add_column('value', justify="left", no_wrap=False)
table.add_row('DOI', desc.doi)
table.add_row('concept DOI', desc.concept_doi)
table.add_row('publication date', desc.publication_date.isoformat())
table.add_row('model type', Group(*desc.model_type))
table.add_row('script', Group(*[iso15924_to_name(x) for x in desc.script]))
table.add_row('alphabet', Group(' '.join(chars), ', '.join(combining)))
table.add_row('keywords', Group(*desc.keywords))
table.add_row('metrics', Group(*_render_metrics(desc.metrics)))
table.add_row('license', desc.license)
table.add_row('creators', Group(*_render_creators(desc.creators)))
table.add_row('description', desc.description)
elif desc.version == 'v1':
table = Table(title=desc.summary, show_header=False)
table.add_column('key', justify="left", no_wrap=True)
table.add_column('value', justify="left", no_wrap=False)
table.add_row('DOI', desc.doi)
table.add_row('concept DOI', desc.concept_doi)
table.add_row('publication date', desc.publication_date.isoformat())
table.add_row('model type', Group(*desc.model_type))
table.add_row('language', Group(*[iso639_3_to_name(x) for x in desc.language]))
table.add_row('script', Group(*[iso15924_to_name(x) for x in desc.script]))
table.add_row('keywords', Group(*desc.keywords))
table.add_row('datasets', Group(*desc.datasets))
table.add_row('metrics', Group(*_render_metrics(desc.metrics)))
table.add_row('base model', Group(*desc.base_model))
table.add_row('software', desc.software_name)
table.add_row('software_hints', Group(*desc.software_hints))
table.add_row('license', desc.license)
table.add_row('creators', Group(*_render_creators(desc.creators)))
table.add_row('description', Markdown(desc.description))

print(table)


@cli.command('list')
Expand All @@ -708,14 +777,41 @@ def list_models(ctx):
"""
Lists models in the repository.
"""
from kraken import repo
from htrmopo import get_listing
from collections import defaultdict
from kraken.lib.progress import KrakenProgressBar

with KrakenProgressBar() as progress:
download_task = progress.add_task('Retrieving model list', total=0, visible=True if not ctx.meta['verbose'] else False)
model_list = repo.get_listing(lambda total, advance: progress.update(download_task, total=total, advance=advance))
for id, metadata in model_list.items():
message('{} ({}) - {}'.format(id, ', '.join(metadata['type']), metadata['summary']))
repository = get_listing(lambda total, advance: progress.update(download_task, total=total, advance=advance))
# aggregate models under their concept DOI
concepts = defaultdict(list)
for item in repository.values():
# both got the same DOI information
record = item['v0'] if item['v0'] else item['v1']
concepts[record.concept_doi].append(record.doi)

table = Table(show_header=True)
table.add_column('DOI', justify="left", no_wrap=True)
table.add_column('summary', justify="left", no_wrap=False)
table.add_column('model type', justify="left", no_wrap=False)
table.add_column('keywords', justify="left", no_wrap=False)

for k, v in concepts.items():
records = [repository[x]['v1'] if 'v1' in repository[x] else repository[x]['v0'] for x in v]
records = filter(lambda record: getattr(record, 'software_name', None) != 'kraken' or 'kraken_pytorch' not in record.keywords, records)
records = sorted(records, key=lambda x: x.publication_date, reverse=True)
if not len(records):
continue

t = Tree(k)
[t.add(x.doi) for x in records]
table.add_row(t,
Group(*[''] + [x.summary for x in records]),
Group(*[''] + ['; '.join(x.model_type) for x in records]),
Group(*[''] + ['; '.join(x.keywords) for x in records]))

print(table)
ctx.exit(0)


Expand Down
Loading

0 comments on commit ffb1ef8

Please sign in to comment.