Skip to content

Commit

Permalink
test: (new) fixed deferred cleanup code
Browse files Browse the repository at this point in the history
  • Loading branch information
Rexwang8 committed May 17, 2024
1 parent 57d55c4 commit 2a7ad6b
Showing 1 changed file with 9 additions and 43 deletions.
52 changes: 9 additions & 43 deletions gpt_bpe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -882,7 +882,7 @@ func TestReadTokenizerConfig(t *testing.T) {
assert.Equal(t, encoder.BosToken, Token(10989))
assert.Equal(t, encoder.PadToken, Token(5428))

// Clean up by removing the downloaded folder
// Finish the test, allow defered cleanup
fmt.Println("All Exists - Looks good.")
}

Expand All @@ -903,10 +903,10 @@ func TestModelDownload(t *testing.T) {
rsrcType = resources.RESOURCETYPE_TRANSFORMERS
hfApiToken := os.Getenv("HF_API_TOKEN")
os.MkdirAll(destPath, 0755)
defer os.RemoveAll(destPath)
_, rsrcErr := resources.ResolveResources(modelId, destPathPTR,
resources.RESOURCE_MODEL, rsrcType, hfApiToken)
if rsrcErr != nil {
os.RemoveAll(destPath)
t.Errorf("Error downloading model resources: %s", rsrcErr)
}

Expand All @@ -921,11 +921,9 @@ func TestModelDownload(t *testing.T) {
fmt.Println("config.json exists")

} else if errors.Is(err, os.ErrNotExist) {
os.RemoveAll(destPath)
t.Errorf("config.json does not exist")

} else {
os.RemoveAll(destPath)
t.Errorf("Error checking for config.json")
}

Expand All @@ -935,11 +933,9 @@ func TestModelDownload(t *testing.T) {
fmt.Println("pytorch_model.bin exists")

} else if errors.Is(err, os.ErrNotExist) {
os.RemoveAll(destPath)
t.Errorf("pytorch_model.bin does not exist")

} else {
os.RemoveAll(destPath)
t.Errorf("Error checking for pytorch_model.bin")
}

Expand All @@ -949,11 +945,9 @@ func TestModelDownload(t *testing.T) {
fmt.Println("tokenizer.json exists")

} else if errors.Is(err, os.ErrNotExist) {
os.RemoveAll(destPath)
t.Errorf("tokenizer.json does not exist")

} else {
os.RemoveAll(destPath)
t.Errorf("Error checking for tokenizer.json")
}

Expand All @@ -963,16 +957,13 @@ func TestModelDownload(t *testing.T) {
fmt.Println("vocab.json exists")

} else if errors.Is(err, os.ErrNotExist) {
os.RemoveAll(destPath)
t.Errorf("vocab.json does not exist")

} else {
os.RemoveAll(destPath)
t.Errorf("Error checking for vocab.json")
}

// Clean up by removing the downloaded folder
os.RemoveAll(destPath)
// Finish the test, allow defered cleanup
fmt.Println("All Exists - Looks good.")
}

Expand All @@ -989,10 +980,10 @@ func TestModelDownloadPythia(t *testing.T) {
rsrcType = resources.RESOURCETYPE_TRANSFORMERS
hfApiToken := os.Getenv("HF_API_TOKEN")
os.MkdirAll(destPath, 0755)
defer os.RemoveAll(destPath)
_, rsrcErr := resources.ResolveResources(modelId, destPathPTR,
resources.RESOURCE_MODEL, rsrcType, hfApiToken)
if rsrcErr != nil {
os.RemoveAll(destPath)
t.Errorf("Error downloading model resources: %s", rsrcErr)
}

Expand All @@ -1007,11 +998,9 @@ func TestModelDownloadPythia(t *testing.T) {
fmt.Println("config.json exists")

} else if errors.Is(err, os.ErrNotExist) {
os.RemoveAll(destPath)
t.Errorf("config.json does not exist")

} else {
os.RemoveAll(destPath)
t.Errorf("Error checking for config.json")
}

Expand All @@ -1021,11 +1010,9 @@ func TestModelDownloadPythia(t *testing.T) {
fmt.Println("pytorch_model.bin exists")

} else if errors.Is(err, os.ErrNotExist) {
os.RemoveAll(destPath)
t.Errorf("pytorch_model.bin does not exist")

} else {
os.RemoveAll(destPath)
t.Errorf("Error checking for pytorch_model.bin")
}

Expand All @@ -1035,11 +1022,9 @@ func TestModelDownloadPythia(t *testing.T) {
fmt.Println("tokenizer.json exists")

} else if errors.Is(err, os.ErrNotExist) {
os.RemoveAll(destPath)
t.Errorf("tokenizer.json does not exist")

} else {
os.RemoveAll(destPath)
t.Errorf("Error checking for tokenizer.json")
}

Expand All @@ -1049,16 +1034,13 @@ func TestModelDownloadPythia(t *testing.T) {
fmt.Println("vocab.json exists")

} else if errors.Is(err, os.ErrNotExist) {
os.RemoveAll(destPath)
t.Errorf("vocab.json does not exist")

} else {
os.RemoveAll(destPath)
t.Errorf("Error checking for vocab.json")
}

// Clean up by removing the downloaded folder
os.RemoveAll(destPath)
// Finish the test, allow defered cleanup
fmt.Println("All Exists - Looks good.")
}

Expand All @@ -1074,10 +1056,10 @@ func TestModelDownloadPythiaSharded(t *testing.T) {
rsrcType = resources.RESOURCETYPE_TRANSFORMERS
hfApiToken := os.Getenv("HF_API_TOKEN")
os.MkdirAll(destPath, 0755)
defer os.RemoveAll(destPath)
_, rsrcErr := resources.ResolveResources(modelId, destPathPTR,
resources.RESOURCE_MODEL, rsrcType, hfApiToken)
if rsrcErr != nil {
os.RemoveAll(destPath)
t.Errorf("Error downloading model resources: %s", rsrcErr)
}

Expand All @@ -1092,11 +1074,9 @@ func TestModelDownloadPythiaSharded(t *testing.T) {
fmt.Println("pytorch_model-00001-of-00002.bin exists")

} else if errors.Is(err, os.ErrNotExist) {
os.RemoveAll(destPath)
t.Errorf("pytorch_model-00001-of-00002.bin does not exist")

} else {
os.RemoveAll(destPath)
t.Errorf("Error checking for pytorch_model-00001-of-00002.bin")
}

Expand All @@ -1106,11 +1086,9 @@ func TestModelDownloadPythiaSharded(t *testing.T) {
fmt.Println("pytorch_model-00002-of-00002.bin exists")

} else if errors.Is(err, os.ErrNotExist) {
os.RemoveAll(destPath)
t.Errorf("pytorch_model-00002-of-00002.bin does not exist")

} else {
os.RemoveAll(destPath)
t.Errorf("Error checking for pytorch_model-00002-of-00002.bin")
}

Expand All @@ -1120,16 +1098,13 @@ func TestModelDownloadPythiaSharded(t *testing.T) {
fmt.Println("pytorch_model.bin.index.json exists")

} else if errors.Is(err, os.ErrNotExist) {
os.RemoveAll(destPath)
t.Errorf("pytorch_model.bin.index.json does not exist")

} else {
os.RemoveAll(destPath)
t.Errorf("Error checking for pytorch_model.bin.index.json")
}

// Clean up by removing the downloaded folder
os.RemoveAll(destPath)
// Finish the test, allow defered cleanup
fmt.Println("All Exists - Looks good.")

}
Expand Down Expand Up @@ -1242,10 +1217,10 @@ func TestModelDownloadFairseq(t *testing.T) {
rsrcType = resources.RESOURCETYPE_TRANSFORMERS
hfApiToken := os.Getenv("HF_API_TOKEN")
os.MkdirAll(destPath, 0755)
defer os.RemoveAll(destPath)
_, rsrcErr := resources.ResolveResources(modelId, destPathPTR,
resources.RESOURCE_MODEL, rsrcType, hfApiToken)
if rsrcErr != nil {
os.RemoveAll(destPath)
t.Errorf("Error downloading model resources: %s", rsrcErr)
}

Expand All @@ -1259,11 +1234,9 @@ func TestModelDownloadFairseq(t *testing.T) {
fmt.Println("config.json exists")

} else if errors.Is(err, os.ErrNotExist) {
os.RemoveAll(destPath)
t.Errorf("config.json does not exist")

} else {
os.RemoveAll(destPath)
t.Errorf("Error checking for config.json")
}

Expand All @@ -1273,11 +1246,9 @@ func TestModelDownloadFairseq(t *testing.T) {
fmt.Println("pytorch_model.bin exists")

} else if errors.Is(err, os.ErrNotExist) {
os.RemoveAll(destPath)
t.Errorf("pytorch_model.bin does not exist")

} else {
os.RemoveAll(destPath)
t.Errorf("Error checking for pytorch_model.bin")
}

Expand All @@ -1287,11 +1258,9 @@ func TestModelDownloadFairseq(t *testing.T) {
fmt.Println("vocab.json exists")

} else if errors.Is(err, os.ErrNotExist) {
os.RemoveAll(destPath)
t.Errorf("vocab.json does not exist")

} else {
os.RemoveAll(destPath)
t.Errorf("Error checking for vocab.json")
}

Expand All @@ -1301,15 +1270,12 @@ func TestModelDownloadFairseq(t *testing.T) {
fmt.Println("merges.txt exists")

} else if errors.Is(err, os.ErrNotExist) {
os.RemoveAll(destPath)
t.Errorf("merges.txt does not exist")

} else {
os.RemoveAll(destPath)
t.Errorf("Error checking for merges.txt")
}

// Clean up by removing the downloaded folder
os.RemoveAll(destPath)
// Finish the test, allow defered cleanup
fmt.Println("All Exists - Looks good (Fairseq Download).")
}

0 comments on commit 2a7ad6b

Please sign in to comment.