Skip to content

Commit c68f8a7

Browse files
ewan0x79frankfliu
andcommitted
[tokenizer] add optional tokenizerPath Prior to modelPath (#3120)
* [tokenizer] add optional tokenizerPath Prior to modelPath --------- Co-authored-by: Frank Liu <[email protected]>
1 parent 56767b9 commit c68f8a7

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import java.io.InputStream;
3131
import java.nio.file.Files;
3232
import java.nio.file.Path;
33+
import java.nio.file.Paths;
3334
import java.util.Arrays;
3435
import java.util.List;
3536
import java.util.Locale;
@@ -686,7 +687,6 @@ static PaddingStrategy fromValue(String value) {
686687
/** The builder for creating huggingface tokenizer. */
687688
public static final class Builder {
688689

689-
private Path tokenizerPath;
690690
private NDManager manager;
691691
private Map<String, String> options;
692692

@@ -724,7 +724,7 @@ public Builder optTokenizerName(String tokenizerName) {
724724
* @return this builder
725725
*/
726726
public Builder optTokenizerPath(Path tokenizerPath) {
727-
this.tokenizerPath = tokenizerPath;
727+
options.putIfAbsent("tokenizerPath", tokenizerPath.toString());
728728
return this;
729729
}
730730

@@ -894,9 +894,11 @@ public HuggingFaceTokenizer build() throws IOException {
894894
if (tokenizerName != null) {
895895
return managed(HuggingFaceTokenizer.newInstance(tokenizerName, options));
896896
}
897-
if (tokenizerPath == null) {
897+
String path = options.get("tokenizerPath");
898+
if (path == null) {
898899
throw new IllegalArgumentException("Missing tokenizer path.");
899900
}
901+
Path tokenizerPath = Paths.get(path);
900902
if (Files.isDirectory(tokenizerPath)) {
901903
Path tokenizerFile = tokenizerPath.resolve("tokenizer.json");
902904
if (Files.exists(tokenizerFile)) {

extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ public void testCrossEncoderTranslator()
6464
.optBlock(block)
6565
.optEngine("PyTorch")
6666
.optArgument("tokenizer", "bert-base-cased")
67+
.optArgument("tokenizerPath", modelDir)
6768
.optOption("hasParameter", "false")
6869
.optTranslatorFactory(new CrossEncoderTranslatorFactory())
6970
.build();

0 commit comments

Comments
 (0)