From 0e00bed8f9621e051b55f8582b7d5e20b58436ae Mon Sep 17 00:00:00 2001 From: Tim Hobson Date: Wed, 23 Oct 2024 16:27:37 +0100 Subject: [PATCH] Fix tests in test_ner.py --- tests/test_ner.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_ner.py b/tests/test_ner.py index 746e636a..28b079ca 100644 --- a/tests/test_ner.py +++ b/tests/test_ner.py @@ -36,15 +36,15 @@ def test_ner_local_train(tmp_path): @pytest.mark.skip(reason="Needs large model file") def test_ner_predict(): - model_path = os.path.join(current_dir,"sample_files/resources/models/ner_test.model") + model_path = os.path.join(current_dir, "../resources/models/") assert os.path.isdir(model_path) is True myner = recogniser.Recogniser( - model="ner_test", + model="blb_lwm-ner-fine", train_dataset=os.path.join(current_dir,"sample_files/experiments/outputs/data/lwm/ner_fine_train.json"), test_dataset=os.path.join(current_dir,"sample_files/experiments/outputs/data/lwm/ner_fine_dev.json"), base_model="Livingwithmachines/bert_1760_1900", - model_path=os.path.join(current_dir,"sample_files/resources/models/"), + model_path=model_path, training_args={ "batch_size": 8, "num_train_epochs": 10, @@ -62,7 +62,7 @@ def test_ner_predict(): predictions = myner.ner_predict(sentence) assert isinstance(predictions, list) assert len(predictions) == 15 - assert predictions[13] == {'entity': 'B-LOC', 'score': 0.7941257357597351, 'word': 'Sheffield', 'start': 74, 'end': 83} + assert predictions[13] == {'entity': 'B-LOC', 'score': 0.9996446371078491, 'word': 'Sheffield', 'start': 74, 'end': 83} # Test that ner_predict() can handle hyphens sentence = "- I grew up in Plymouth—Kingston."