Skip to content

Commit

Permalink
Fix tests in test_ner.py
Browse files Browse the repository at this point in the history
  • Loading branch information
thobson88 committed Oct 23, 2024
1 parent 559ffa8 commit 0e00bed
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tests/test_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ def test_ner_local_train(tmp_path):

@pytest.mark.skip(reason="Needs large model file")
def test_ner_predict():
model_path = os.path.join(current_dir,"sample_files/resources/models/ner_test.model")
model_path = os.path.join(current_dir, "../resources/models/")
assert os.path.isdir(model_path) is True

myner = recogniser.Recogniser(
model="ner_test",
model="blb_lwm-ner-fine",
train_dataset=os.path.join(current_dir,"sample_files/experiments/outputs/data/lwm/ner_fine_train.json"),
test_dataset=os.path.join(current_dir,"sample_files/experiments/outputs/data/lwm/ner_fine_dev.json"),
base_model="Livingwithmachines/bert_1760_1900",
model_path=os.path.join(current_dir,"sample_files/resources/models/"),
model_path=model_path,
training_args={
"batch_size": 8,
"num_train_epochs": 10,
Expand All @@ -62,7 +62,7 @@ def test_ner_predict():
predictions = myner.ner_predict(sentence)
assert isinstance(predictions, list)
assert len(predictions) == 15
assert predictions[13] == {'entity': 'B-LOC', 'score': 0.7941257357597351, 'word': 'Sheffield', 'start': 74, 'end': 83}
assert predictions[13] == {'entity': 'B-LOC', 'score': 0.9996446371078491, 'word': 'Sheffield', 'start': 74, 'end': 83}

# Test that ner_predict() can handle hyphens
sentence = "- I grew up in Plymouth—Kingston."
Expand Down

0 comments on commit 0e00bed

Please sign in to comment.