Skip to content

Commit

Permalink
Factor out htrmopo calls to include filters
Browse files Browse the repository at this point in the history
  • Loading branch information
mittagessen committed Jan 4, 2025
1 parent ffb1ef8 commit 716c520
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 31 deletions.
56 changes: 25 additions & 31 deletions kraken/kraken.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,8 +694,8 @@ def show(ctx, metadata_version, model_id):
"""
Retrieves model metadata from the repository.
"""
from htrmopo import get_description
from htrmopo.util import iso15924_to_name, iso639_3_to_name
from kraken.repo import get_description
from kraken.lib.util import is_printable, make_printable

def _render_creators(creators):
Expand All @@ -716,15 +716,13 @@ def _render_metrics(metrics):
metadata_version = None

try:
desc = get_description(model_id, version=metadata_version)
desc = get_description(model_id,
version=metadata_version,
filter_fn=lambda record: getattr(record, 'software_name', None) == 'kraken' or 'kraken_pytorch' in record.keywords)
except ValueError as e:
logger.error(e)
ctx.exit(1)

if getattr(desc, 'software_name', None) != 'kraken' or 'kraken_pytorch' not in desc.keywords:
logger.error('Record exists but is not a kraken-compatible model')
ctx.exit(1)

if desc.version == 'v0':
chars = []
combining = []
Expand Down Expand Up @@ -777,33 +775,21 @@ def list_models(ctx):
"""
Lists models in the repository.
"""
from htrmopo import get_listing
from collections import defaultdict
from kraken.repo import get_listing
from kraken.lib.progress import KrakenProgressBar

with KrakenProgressBar() as progress:
download_task = progress.add_task('Retrieving model list', total=0, visible=True if not ctx.meta['verbose'] else False)
repository = get_listing(lambda total, advance: progress.update(download_task, total=total, advance=advance))
# aggregate models under their concept DOI
concepts = defaultdict(list)
for item in repository.values():
# both got the same DOI information
record = item['v0'] if item['v0'] else item['v1']
concepts[record.concept_doi].append(record.doi)
repository = get_listing(callback=lambda total, advance: progress.update(download_task, total=total, advance=advance),
filter_fn=lambda record: getattr(record, 'software_name', None) == 'kraken' or 'kraken_pytorch' in record.keywords)

table = Table(show_header=True)
table.add_column('DOI', justify="left", no_wrap=True)
table.add_column('summary', justify="left", no_wrap=False)
table.add_column('model type', justify="left", no_wrap=False)
table.add_column('keywords', justify="left", no_wrap=False)

for k, v in concepts.items():
records = [repository[x]['v1'] if 'v1' in repository[x] else repository[x]['v0'] for x in v]
records = filter(lambda record: getattr(record, 'software_name', None) != 'kraken' or 'kraken_pytorch' not in record.keywords, records)
records = sorted(records, key=lambda x: x.publication_date, reverse=True)
if not len(records):
continue

for k, records in repository.items():
t = Tree(k)
[t.add(x.doi) for x in records]
table.add_row(t,
Expand All @@ -812,7 +798,6 @@ def list_models(ctx):
Group(*[''] + ['; '.join(x.keywords) for x in records]))

print(table)
ctx.exit(0)


@cli.command('get')
Expand All @@ -822,20 +807,29 @@ def get(ctx, model_id):
"""
Retrieves a model from the repository.
"""
from kraken import repo
import glob

from htrmopo import get_model, get_description

from kraken.lib.progress import KrakenDownloadProgressBar

try:
os.makedirs(click.get_app_dir(APP_NAME))
except OSError:
pass
desc = get_description(model_id)
except ValueError as e:
logger.error(e)
ctx.exit(1)

print(desc)
if getattr(desc, 'software_name', None) != 'kraken' or 'kraken_pytorch' not in desc.keywords:
logger.error('Record exists but is not a kraken-compatible model')
ctx.exit(1)

with KrakenDownloadProgressBar() as progress:
download_task = progress.add_task('Processing', total=0, visible=True if not ctx.meta['verbose'] else False)
filename = repo.get_model(model_id, click.get_app_dir(APP_NAME),
lambda total, advance: progress.update(download_task, total=total, advance=advance))
message(f'Model name: {filename}')
ctx.exit(0)
model_dir = get_model(model_id,
lambda total, advance: progress.update(download_task, total=total, advance=advance))
model_candidates = list(filter(lambda x: x.suffix == '.mlmodel', model_dir.iter_dir()))
message(f'Model dir: {model_dir} (model files: {model_candidates})')


if __name__ == '__main__':
Expand Down
87 changes: 87 additions & 0 deletions kraken/repo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#
# Copyright 2015 Benjamin Kiessling
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""
kraken.repo
~~~~~~~~~~~
Wrappers around the htrmopo reference implementation implementing
kraken-specific filtering.
"""
import logging
import warnings
from pathlib import Path
from collections import defaultdict
from typing import IO, Any, Dict, List, Union, cast, Optional, TypeVar, Iterable, Literal

from collections.abc import Callable

from htrmopo import get_description as mopo_get_description
from htrmopo import get_listing as mopo_get_listing
from htrmopo.record import v0RepositoryRecord, v1RepositoryRecord


_v0_or_v1_Record = TypeVar('_v0_or_v1_Record', v0RepositoryRecord, v1RepositoryRecord)


def get_description(model_id: str,
callback: Callable[..., Any] = lambda: None,
version: Optional[Literal['v0', 'v1']] = None,
filter_fn: Optional[Callable[[_v0_or_v1_Record], bool]] = lambda x: True) -> _v0_or_v1_Record:
"""
Filters the output of htrmopo.get_description with a custom function.
Args:
model_id: model DOI
callback: Progress callback
version:
filter_fn: Function called to filter the retrieved record.
"""
desc = mopo_get_description(model_id, callback, version)
if not filter_fn(desc):
raise ValueError(f'Record {model_id} exists but is not a valid kraken record')
return desc


def get_listing(callback: Callable[[int, int], Any] = lambda total, advance: None,
from_date: Optional[str] = None,
filter_fn: Optional[Callable[[_v0_or_v1_Record], bool]] = lambda x: True) -> Dict[str, Dict[str, _v0_or_v1_Record]]:
"""
Returns a filtered representation of the model repository grouped by
concept DOI.
Args:
callback: Progress callback
from_data:
filter_fn: Function called for each record object
Returns:
A dictionary mapping group DOIs to one record object per deposit. The
record of the highest available schema version is retained.
"""
repository = mopo_get_listing(callback, from_date)
# aggregate models under their concept DOI
concepts = defaultdict(list)
for item in repository.values():
# filter records here
item = {k: v for k, v in item.items() if filter_fn(v)}
# both got the same DOI information
record = item.get('v1', item.get('v0', None))
if record is not None:
concepts[record.concept_doi].append(record)

for k, v in concepts.items():
concepts[k] = sorted(v, key=lambda x: x.publication_date, reverse=True)

return concepts

0 comments on commit 716c520

Please sign in to comment.