feat(dynamo-run): Basic routing choice #524

Merged

merged 1 commit into from Apr 7, 2025

36 changes: 36 additions & 0 deletions launch/dynamo-run/src/flags.rs
@@ -17,6 +17,9 @@ use std::collections::HashMap;
use std::path::PathBuf;
use std::str::FromStr;

use clap::ValueEnum;
use dynamo_runtime::component::RouterMode as RuntimeRouterMode;

/// Required options depend on the in and out choices
#[derive(clap::Parser, Debug, Clone)]
#[command(version, about, long_about = None)]
@@ -92,6 +95,13 @@ pub struct Flags {
#[arg(long)]
pub leader_addr: Option<String>,

/// If using `out=dyn://..` with multiple backends, this says how to route the requests.
///
/// Mostly interesting for KV-aware routing.
/// Defaults to RouterMode::Random
#[arg(long, default_value = "random")]
pub router_mode: RouterMode,

/// Internal use only.
// Start the python vllm engine sub-process.
#[arg(long, hide = true, default_value = "false")]
@@ -198,3 +208,29 @@ fn parse_sglang_flags(s: &str) -> Result<SgLangFlags, String> {
gpu_id: nums[2],
})
}

#[derive(Default, PartialEq, Eq, ValueEnum, Clone, Debug)]
pub enum RouterMode {
#[default]
Random,
#[value(name = "round-robin")]
RoundRobin,
#[value(name = "kv")]
KV,
}

impl RouterMode {
pub fn is_kv_routing(&self) -> bool {
*self == RouterMode::KV
}
}

impl From<RouterMode> for RuntimeRouterMode {
fn from(r: RouterMode) -> RuntimeRouterMode {
match r {
RouterMode::RoundRobin => RuntimeRouterMode::RoundRobin,
RouterMode::KV => todo!("KV not implemented yet"),
_ => RuntimeRouterMode::Random,
}
}
}
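
For context, here is a minimal, self-contained sketch of how this `ValueEnum` wiring behaves at the command line, assuming clap 4.x with the `derive` feature (the `Cli` struct is a stand-in for the real `Flags`):

```rust
use clap::{Parser, ValueEnum};

#[derive(Default, PartialEq, Eq, ValueEnum, Clone, Debug)]
enum RouterMode {
    #[default]
    Random,
    #[value(name = "round-robin")]
    RoundRobin,
    #[value(name = "kv")]
    KV,
}

#[derive(Parser, Debug)]
struct Cli {
    /// How to route requests across multiple backends.
    #[arg(long, default_value = "random")]
    router_mode: RouterMode,
}

fn main() {
    // `--router-mode round-robin` parses to RouterMode::RoundRobin;
    // omitting the flag falls back to RouterMode::Random via #[default].
    let cli = Cli::parse();
    println!("routing with {:?}", cli.router_mode);
}
```

clap derives the accepted values from the enum, so `--router-mode` rejects anything other than `random`, `round-robin`, or `kv` with a generated error message.
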
7 changes: 4 additions & 3 deletions launch/dynamo-run/src/input/batch.rs
@@ -31,7 +31,7 @@ use std::time::{Duration, Instant};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt};

use crate::input::common;
use crate::EngineConfig;
use crate::{EngineConfig, Flags};

/// Max tokens in each response.
/// TODO: For batch mode this should be the full context size of the model
@@ -64,11 +64,12 @@ struct Entry

pub async fn run(
runtime: Runtime,
cancel_token: CancellationToken,
flags: Flags,
maybe_card: Option<ModelDeploymentCard>,
input_jsonl: PathBuf,
engine_config: EngineConfig,
) -> anyhow::Result<()> {
let cancel_token = runtime.primary_token();
// Check if the path exists and is a directory
if !input_jsonl.exists() || !input_jsonl.is_file() {
anyhow::bail!(
@@ -78,7 +79,7 @@
}

let (service_name, engine, _inspect_template) =
common::prepare_engine(runtime.clone(), engine_config).await?;
common::prepare_engine(runtime, flags, engine_config).await?;
let service_name_ref = Arc::new(service_name);

let pre_processor = if let Some(card) = maybe_card {
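
The same signature change repeats across the input modules below: each entry point now takes the parsed `Flags` and derives its cancellation token from the `Runtime` instead of receiving one as a parameter. A minimal sketch of the pattern, with hypothetical stand-in types for the real `dynamo_runtime` ones:

```rust
#[derive(Clone)]
struct CancellationToken;

struct Runtime {
    token: CancellationToken,
}

impl Runtime {
    // Callers no longer pass a token in; they ask the runtime for it.
    fn primary_token(&self) -> CancellationToken {
        self.token.clone()
    }
}

struct Flags; // parsed CLI flags, threaded through to prepare_engine

async fn run(runtime: Runtime, _flags: Flags) -> anyhow::Result<()> {
    let _cancel_token = runtime.primary_token();
    // ... build the engine and loop until the token is cancelled ...
    Ok(())
}
```
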
24 changes: 16 additions & 8 deletions launch/dynamo-run/src/input/common.rs
@@ -13,7 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::EngineConfig;
use crate::{flags::RouterMode, EngineConfig, Flags};
use dynamo_llm::{
backend::Backend,
preprocessor::OpenAIPreprocessor,
@@ -34,21 +34,29 @@ use std::sync::Arc;
/// Turns an EngineConfig into an OpenAIChatCompletionsStreamingEngine.
pub async fn prepare_engine(
runtime: Runtime,
flags: Flags,
engine_config: EngineConfig,
) -> anyhow::Result<(String, OpenAIChatCompletionsStreamingEngine, bool)> {
match engine_config {
EngineConfig::Dynamic(endpoint_id) => {
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;

let endpoint = distributed_runtime
.namespace(endpoint_id.namespace)?
.component(endpoint_id.component)?
.endpoint(endpoint_id.name);
.namespace(endpoint_id.namespace.clone())?
.component(endpoint_id.component.clone())?
.endpoint(endpoint_id.name.clone());

let client = endpoint.client::<NvCreateChatCompletionRequest, Annotated<NvCreateChatCompletionStreamResponse>>().await?;
tracing::info!("Waiting for remote model..");
client.wait_for_endpoints().await?;
tracing::info!("Model discovered");
let mut client = endpoint.client::<NvCreateChatCompletionRequest, Annotated<NvCreateChatCompletionStreamResponse>>().await?;

match &flags.router_mode {
RouterMode::Random | RouterMode::RoundRobin => {
client.set_router_mode(flags.router_mode.into());
tracing::info!("Waiting for remote model..");
client.wait_for_endpoints().await?;
tracing::info!("Model discovered");
}
RouterMode::KV => todo!(),
}

// The service_name isn't used for text chat outside of logs,
// so use the path. That avoids having to listen on etcd for model registration.
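
A self-contained sketch of the configure-then-wait dispatch added above, using a mock client in place of the `dynamo_runtime` client (`MockClient` and its method bodies are hypothetical; the call order mirrors the hunk):

```rust
enum RouterMode {
    Random,
    RoundRobin,
    KV,
}

struct MockClient;

impl MockClient {
    // The real client's setter takes the runtime's RouterMode; a &str
    // stands in here to keep the sketch dependency-free.
    fn set_router_mode(&mut self, _mode: &str) {}

    async fn wait_for_endpoints(&self) -> anyhow::Result<()> {
        // The real call blocks until at least one backend has registered.
        Ok(())
    }
}

async fn configure(mut client: MockClient, mode: RouterMode) -> anyhow::Result<()> {
    match mode {
        // The router mode is set before waiting for endpoints, as in the PR.
        RouterMode::Random => client.set_router_mode("random"),
        RouterMode::RoundRobin => client.set_router_mode("round-robin"),
        // KV-aware routing is still a todo!() at this stage of the PR.
        RouterMode::KV => anyhow::bail!("KV routing not implemented yet"),
    }
    client.wait_for_endpoints().await
}
```
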
15 changes: 8 additions & 7 deletions launch/dynamo-run/src/input/endpoint.rs
@@ -28,21 +28,22 @@ use dynamo_llm::{
use dynamo_runtime::pipeline::{
network::Ingress, ManyOut, Operator, SegmentSource, ServiceBackend, SingleIn, Source,
};
use dynamo_runtime::{protocols::Endpoint, DistributedRuntime, Runtime};
use dynamo_runtime::{protocols::Endpoint, DistributedRuntime};

use crate::EngineConfig;

pub async fn run(
runtime: Runtime,
distributed_runtime: DistributedRuntime,
path: String,
engine_config: EngineConfig,
) -> anyhow::Result<()> {
// This will attempt to connect to NATS and etcd
let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;

let cancel_token = runtime.primary_token().clone();
let cancel_token = distributed_runtime.primary_token().clone();
let endpoint_id: Endpoint = path.parse()?;

let etcd_client = distributed_runtime.etcd_client();

let (ingress, service_name) = match engine_config {
EngineConfig::StaticFull {
service_name,
@@ -85,7 +86,7 @@
model_type: ModelType::Chat,
};

let component = distributed
let component = distributed_runtime
.namespace(endpoint_id.namespace)?
.component(endpoint_id.component)?;
let endpoint = component
@@ -94,8 +95,8 @@
.await?
.endpoint(endpoint_id.name);

if let Some(etcd_client) = distributed.etcd_client() {
let network_name = endpoint.subject();
if let Some(etcd_client) = etcd_client {
let network_name = endpoint.subject_to(etcd_client.lease_id());
tracing::debug!("Registering with etcd as {network_name}");
etcd_client
.kv_create(
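
The registration change above swaps `endpoint.subject()` for `endpoint.subject_to(etcd_client.lease_id())`, binding the name a worker registers under to its etcd lease. The exact name format is internal to `dynamo_runtime`; this rough sketch only illustrates the idea, with an assumed format:

```rust
// Assumed format, not the library's actual scheme: embedding the lease id
// makes each worker's registration name unique per instance.
fn subject_to(base_subject: &str, lease_id: i64) -> String {
    format!("{base_subject}-{lease_id:x}")
}

fn main() {
    let name = subject_to("namespace.component.endpoint", 0x7abc_1234);
    // Keys created under the worker's lease vanish when the lease expires
    // (standard etcd semantics), so stale workers drop out of discovery.
    println!("registering as {name}");
}
```
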
6 changes: 3 additions & 3 deletions launch/dynamo-run/src/input/http.rs
@@ -32,16 +32,16 @@ use dynamo_runtime::{
DistributedRuntime, Runtime,
};

use crate::EngineConfig;
use crate::{EngineConfig, Flags};

/// Build and run an HTTP service
pub async fn run(
runtime: Runtime,
http_port: u16,
flags: Flags,
engine_config: EngineConfig,
) -> anyhow::Result<()> {
let http_service = service_v2::HttpService::builder()
.port(http_port)
.port(flags.http_port)
.enable_chat_endpoints(true)
.enable_cmpl_endpoints(true)
.build()?;
7 changes: 4 additions & 3 deletions launch/dynamo-run/src/input/text.rs
@@ -21,23 +21,24 @@ use futures::StreamExt;
use std::io::{ErrorKind, Write};

use crate::input::common;
use crate::EngineConfig;
use crate::{EngineConfig, Flags};

/// Max response tokens for each single query. Must be less than model context size.
/// TODO: Cmd line flag to overwrite this
const MAX_TOKENS: u32 = 8192;

pub async fn run(
runtime: Runtime,
cancel_token: CancellationToken,
flags: Flags,
single_prompt: Option<String>,
engine_config: EngineConfig,
) -> anyhow::Result<()> {
let cancel_token = runtime.primary_token();
let (service_name, engine, inspect_template): (
String,
OpenAIChatCompletionsStreamingEngine,
bool,
) = common::prepare_engine(runtime.clone(), engine_config).await?;
) = common::prepare_engine(runtime, flags, engine_config).await?;
main_loop(
cancel_token,
&service_name,