build example/main.cpp as shared library and intercept token printing using FFI #8339

Open · wants to merge 1 commit into master
80 changes: 80 additions & 0 deletions examples/main/README.md
@@ -12,6 +12,7 @@ This example program allows you to use various LLaMA language models in an easy
6. [Generation Flags](#generation-flags)
7. [Performance Tuning and Memory Options](#performance-tuning-and-memory-options)
8. [Additional Options](#additional-options)
9. [Shared Library](#shared-library)

## Quick Start

@@ -314,3 +315,82 @@ These options provide extra functionality and customization when running the LLa
- `--lora-base FNAME`: Optional model to use as a base for the layers modified by the LoRA adapter. This flag is used in conjunction with the `--lora` flag, and specifies the base model for the adaptation.

- `-hfr URL --hf-repo URL`: The url to the Hugging Face model repository. Used in conjunction with `--hf-file` or `-hff`. The model is downloaded and stored in the file provided by `-m` or `--model`. If `-m` is not provided, the model is auto-stored in the path specified by the `LLAMA_CACHE` environment variable or in an OS-specific local cache.

## Shared Library

To build `llama-cli` as a shared library, run the following command from the root directory of the repository:

```bash
CXXFLAGS="-DSHARED_LIB" LDFLAGS="-shared -o libllama-cli.so" make llama-cli
```

The resulting `libllama-cli.so` exports the function `llama_cli_main`, which can be invoked via FFI with the same command-line options that `llama-cli` accepts:

```c
int llama_cli_main(int argc, char ** argv);
```

To let you install custom STDOUT and STDERR streams and intercept token printing, we provide four functions:

```c
void llama_set_stdout(FILE* f);
void llama_set_stderr(FILE* f);
void llama_set_fprintf(int (*func)(FILE*, const char*, ...));
void llama_set_fflush(int (*func)(FILE*));
```

This is particularly beneficial if you need to use `libllama-cli.so` through FFI in other programming languages without altering the default STDOUT and STDERR file descriptors.
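
For example, `llama_set_stdout` can point the library's output at a pipe. The snippet below is a minimal sketch, assuming a Unix-like system where libc's `fdopen` is reachable through `ctypes`; the `read_fd`/`write_fd` and `write_file` names are illustrative, not part of the API:

```python
import os
from ctypes import CDLL, c_char_p, c_int, c_void_p

libc = CDLL(None)                         # symbols of the already-loaded C library
libc.fdopen.restype = c_void_p            # FILE *
libc.fdopen.argtypes = [c_int, c_char_p]

lib = CDLL('./libllama-cli.so')
lib.llama_set_stdout.argtypes = [c_void_p]

read_fd, write_fd = os.pipe()
write_file = libc.fdopen(write_fd, b'w')  # wrap the pipe's write end in a FILE *
lib.llama_set_stdout(write_file)          # generated tokens now go to the pipe

# read_fd can then be drained with os.read(read_fd, n), e.g. from another
# thread, while lib.llama_cli_main(...) runs
```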

Here's a Python example that handles token printing on its own, without relying on STDOUT and STDERR:

```python
from ctypes import *

#
# open shared library
#
lib = CDLL('./libllama-cli.so')
lib.llama_cli_main.argtypes = [c_int, POINTER(c_char_p)]
lib.llama_cli_main.restype = c_int

#
# redefine fprintf and fflush
#
@CFUNCTYPE(c_int, c_void_p, c_char_p, c_char_p)
def fprintf(file_obj, fmt, *args):
    # NOTE: the prototype above assumes a single string argument, which matches
    # the token-printing call `llama_fprintf(llama_stdout, "%s", token_str.c_str())`
    content = fmt.decode('utf-8') % tuple(arg.decode('utf-8') for arg in args)
    print(content, flush=True, end='')
    return len(content)


@CFUNCTYPE(c_int, c_void_p)
def fflush(file_obj):
    print(flush=True, end='')
    return 0


lib.llama_set_fprintf(fprintf)
lib.llama_set_fflush(fflush)

#
# generate and print token by token
#
argv: list[bytes] = [
    b'llama-cli',
    b'-m',
    b'models/7B/ggml-model.bin',
    b'--no-display-prompt',
    b'--simple-io',
    b'--log-disable',
    b'-p',
    b'What is cosmos?',
]

argc = len(argv)
argv = (c_char_p * argc)(*argv)
res = lib.llama_cli_main(argc, argv)
assert res == 0
```

If needed, the Python implementation of `fprintf` can capture the generated tokens instead of printing them.
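
For instance, here is a minimal sketch (the `captured` list and `capture_fprintf` name are illustrative) that reuses the `lib` handle from the example above and collects tokens into a list instead of printing them:

```python
from ctypes import CFUNCTYPE, c_char_p, c_int, c_void_p

captured: list[str] = []

@CFUNCTYPE(c_int, c_void_p, c_char_p, c_char_p)
def capture_fprintf(file_obj, fmt, *args):
    # same single-string-argument assumption as the fprintf callback above
    content = fmt.decode('utf-8') % tuple(arg.decode('utf-8') for arg in args)
    captured.append(content)
    return len(content)

lib.llama_set_fprintf(capture_fprintf)
# ... call lib.llama_cli_main(argc, argv) as above, then inspect `captured`
```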
62 changes: 55 additions & 7 deletions examples/main/main.cpp
@@ -39,6 +39,39 @@ static std::ostringstream * g_output_ss;
static std::vector<llama_token> * g_output_tokens;
static bool is_interacting = false;
static bool need_insert_eot = false;
static FILE *llama_stdout = stdout;
static FILE *llama_stderr = stderr;
static int (*llama_fprintf)(FILE*, const char*, ...) = fprintf;
static int (*llama_fflush)(FILE*) = fflush;

#ifdef __cplusplus
extern "C" {
#endif

void llama_set_stdout(FILE* f);
void llama_set_stderr(FILE* f);
void llama_set_fprintf(int (*func)(FILE*, const char*, ...));
void llama_set_fflush(int (*func)(FILE*));

void llama_set_stdout(FILE* f) {
    llama_stdout = f;
}

void llama_set_stderr(FILE* f) {
    llama_stderr = f;
}

void llama_set_fprintf(int (*func)(FILE*, const char*, ...)) {
    llama_fprintf = func;
}

void llama_set_fflush(int (*func)(FILE*)) {
    llama_fflush = func;
}

#ifdef __cplusplus
}
#endif

static bool file_exists(const std::string & path) {
std::ifstream f(path.c_str());
@@ -65,7 +98,7 @@ static void write_logfile(

const bool success = fs_create_directory_with_parents(params.logdir);
if (!success) {
fprintf(stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n",
llama_fprintf(llama_stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n",
__func__, params.logdir.c_str());
return;
}
Expand All @@ -74,7 +107,7 @@ static void write_logfile(
FILE * logfile = fopen(logfile_path.c_str(), "w");

if (logfile == NULL) {
fprintf(stderr, "%s: failed to open logfile %s\n", __func__, logfile_path.c_str());
llama_fprintf(llama_stderr, "%s: failed to open logfile %s\n", __func__, logfile_path.c_str());
return;
}

@@ -127,7 +160,18 @@ static std::string chat_add_and_format(struct llama_model * model, std::vector<l
return formatted;
}

int main(int argc, char ** argv) {
#ifdef __cplusplus
extern "C" {
#endif

#ifdef SHARED_LIB
int llama_cli_main(int argc, char ** argv);

int llama_cli_main(int argc, char ** argv)
#else
int main(int argc, char ** argv)
#endif
{
gpt_params params;
g_params = &params;

@@ -524,7 +568,7 @@ int main(int argc, char ** argv) {

struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
if (!ctx_sampling) {
fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
llama_fprintf(llama_stderr, "%s: failed to initialize sampling subsystem\n", __func__);
exit(1);
}

@@ -561,7 +605,7 @@ int main(int argc, char ** argv) {
console::set_display(console::error);
printf("<<input too long: skipped %d token%s>>", skipped_tokens, skipped_tokens != 1 ? "s" : "");
console::set_display(console::reset);
fflush(stdout);
llama_fflush(llama_stdout);
}

if (ga_n == 1) {
@@ -761,7 +805,7 @@ int main(int argc, char ** argv) {
const std::string token_str = llama_token_to_piece(ctx, id, params.special);

// Console/Stream Output
fprintf(stdout, "%s", token_str.c_str());
llama_fprintf(llama_stdout, "%s", token_str.c_str());

// Record Displayed Tokens To Log
// Note: Generated tokens are created one by one hence this check
@@ -774,7 +818,7 @@
output_ss << token_str;
}

fflush(stdout);
llama_fflush(llama_stdout);
}
}

@@ -986,3 +1030,7 @@ int main(int argc, char ** argv) {

return 0;
}

#ifdef __cplusplus
}
#endif