Skip to content

[AQUA][MMD] Enhance FT weights. #1206

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 15 additions & 34 deletions ads/aqua/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from ads.aqua import logger
from ads.aqua.app import AquaApp
from ads.aqua.common.entities import AquaMultiModelRef, LoraModuleSpec
from ads.aqua.common.entities import AquaMultiModelRef
from ads.aqua.common.enums import (
ConfigFolder,
CustomInferenceContainerTypeFamily,
Expand Down Expand Up @@ -84,7 +84,6 @@
)
from ads.aqua.model.enums import MultiModelSupportedTaskType
from ads.aqua.model.utils import (
extract_base_model_from_ft,
extract_fine_tune_artifacts_path,
)
from ads.common.auth import default_signer
Expand Down Expand Up @@ -238,6 +237,7 @@ def create(
def create_multi(
self,
models: List[AquaMultiModelRef],
model_details: Dict[str, DataScienceModel],
project_id: Optional[str] = None,
compartment_id: Optional[str] = None,
freeform_tags: Optional[Dict] = None,
Expand All @@ -251,6 +251,8 @@ def create_multi(
----------
models : List[AquaMultiModelRef]
List of AquaMultiModelRef instances for creating a multi-model group.
model_details :
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this for base models or just FT models? The description is not clear.

A dict that contains model id and its corresponding data science model.
project_id : Optional[str]
The project ID for the multi-model group.
compartment_id : Optional[str]
Expand Down Expand Up @@ -298,7 +300,7 @@ def create_multi(
# Process each model in the input list
for model in models:
# Retrieve model metadata from the Model Catalog using the model ID
source_model = DataScienceModel.from_id(model.model_id)
source_model: DataScienceModel = model_details.get(model.model_id)
display_name = source_model.display_name
model_file_description = source_model.model_file_description
# If model_name is not explicitly provided, use the model's display name
Expand All @@ -310,42 +312,21 @@ def create_multi(
"Please register the model first."
)

# Check if the model is a fine-tuned model based on its tags
is_fine_tuned_model = (
Tags.AQUA_FINE_TUNED_MODEL_TAG in source_model.freeform_tags
)

base_model_artifact_path = ""
fine_tune_path = ""

if is_fine_tuned_model:
# Extract artifact paths for the base and fine-tuned model
base_model_artifact_path, fine_tune_path = (
extract_fine_tune_artifacts_path(source_model)
)

# Create a single LoRA module specification for the fine-tuned model
# TODO: Support multiple LoRA modules in the future
model.fine_tune_weights = [
LoraModuleSpec(
model_id=model.model_id,
model_name=model.model_name,
model_path=fine_tune_path,
if model.fine_tune_weights:
for loral_module_spec in model.fine_tune_weights:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: lora_module_spec? (not sure if loral is intentional)

fine_tune_model: DataScienceModel = model_details.get(
loral_module_spec.model_id
)
]

# Use the LoRA module name as the model's display name
display_name = model.model_name

# Temporarily override model ID and name with those of the base model
# TODO: Revisit this logic once proper base/FT model handling is implemented
model.model_id, model.model_name = extract_base_model_from_ft(
source_model
)
# Extract artifact paths for the base and fine-tuned model
base_model_artifact_path, fine_tune_path = (
extract_fine_tune_artifacts_path(fine_tune_model)
)
loral_module_spec.model_path = fine_tune_path
else:
# For base models, use the original artifact path
base_model_artifact_path = source_model.artifact
display_name = model.model_name

if not base_model_artifact_path:
# Fail if no artifact is found for the base model model
Expand All @@ -356,7 +337,7 @@ def create_multi(

# Update the artifact path in the model configuration
model.artifact_location = base_model_artifact_path
display_name_list.append(display_name)
display_name_list.append(model.model_name)

# Extract model task metadata from source model
self._extract_model_task(model, source_model)
Expand Down
32 changes: 27 additions & 5 deletions ads/aqua/modeldeployment/config_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,17 +356,39 @@ def _fetch_deployment_configs_concurrently(
self, model_ids: List[str]
) -> Dict[str, AquaDeploymentConfig]:
"""Fetches deployment configurations in parallel using ThreadPoolExecutor."""
with ThreadPoolExecutor(max_workers=self.MAX_WORKERS) as executor:
results = executor.map(
self._fetch_deployment_config_from_metadata_and_oss,
model_ids,
results: Dict[str, ModelConfigResult] = (
self.fetch_data_science_resources_concurrently(
model_ids, self._fetch_deployment_config_from_metadata_and_oss
)
)

return {
model_id: AquaDeploymentConfig(**config.config)
for model_id, config in zip(model_ids, results)
for model_id, config in results.items()
}

@classmethod
def fetch_data_science_resources_concurrently(
    cls, model_ids: List[str], function
) -> dict:
    """Fetches data science resources in parallel using ThreadPoolExecutor.

    Parameters
    ----------
    model_ids: List[str]
        The model ids to fetch resources for; ``function`` is invoked once
        per id, with up to ``cls.MAX_WORKERS`` calls running concurrently.
    function:
        A callable that accepts a single model id and returns the
        corresponding data science resource.

    Returns
    -------
    dict:
        A dict mapping each model id to its corresponding data science
        resource, in the same order as ``model_ids``.
    """
    with ThreadPoolExecutor(max_workers=cls.MAX_WORKERS) as executor:
        # Materialize the lazy map() generator inside the executor context
        # so any worker exception surfaces here, not at a later iteration.
        results = list(executor.map(function, model_ids))

    return dict(zip(model_ids, results))

def _fetch_deployment_config_from_metadata_and_oss(
self, model_id: str
) -> ModelConfigResult:
Expand Down
28 changes: 25 additions & 3 deletions ads/aqua/modeldeployment/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import copy
import json
import shlex
from datetime import datetime, timedelta
Expand Down Expand Up @@ -214,11 +215,31 @@ def create(
container_config=container_config,
)
else:
model_ids = [model.model_id for model in create_deployment_details.models]
base_model_ids = [
model.model_id for model in create_deployment_details.models
]

model_ids = copy.deepcopy(base_model_ids)
for model in create_deployment_details.models:
if model.fine_tune_weights:
model_ids.extend(
[
fine_tune_model.model_id
for fine_tune_model in model.fine_tune_weights
]
)

model_details = MultiModelDeploymentConfigLoader.fetch_data_science_resources_concurrently(
model_ids, DataScienceModel.from_id
)
try:
create_deployment_details.validate_input_models(model_details)
except ConfigValidationError as err:
raise AquaValueError(f"{err}") from err

try:
model_config_summary = self.get_multimodel_deployment_config(
model_ids=model_ids, compartment_id=compartment_id
model_ids=base_model_ids, compartment_id=compartment_id
)
if not model_config_summary.gpu_allocation:
raise AquaValueError(model_config_summary.error_message)
Expand Down Expand Up @@ -292,11 +313,12 @@ def create(
)

logger.debug(
f"Multi models ({model_ids}) provided. Delegating to multi model creation method."
f"Multi models ({base_model_ids}) provided. Delegating to multi model creation method."
)

aqua_model = model_app.create_multi(
models=create_deployment_details.models,
model_details=model_details,
compartment_id=compartment_id,
project_id=project_id,
freeform_tags=freeform_tags,
Expand Down
57 changes: 57 additions & 0 deletions ads/aqua/modeldeployment/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
from ads.aqua.config.utils.serializer import Serializable
from ads.aqua.constants import UNKNOWN_DICT
from ads.aqua.data import AquaResourceIdentifier
from ads.aqua.finetuning.constants import FineTuneCustomMetadata
from ads.aqua.modeldeployment.config_loader import (
ConfigurationItem,
ModelDeploymentConfigSummary,
)
from ads.common.serializer import DataClassSerializable
from ads.common.utils import UNKNOWN, get_console_link
from ads.model.model_metadata import ModelCustomMetadataItem


class ConfigValidationError(Exception):
Expand Down Expand Up @@ -474,6 +476,61 @@ def validate_multimodel_deployment_feasibility(
logger.error(error_message)
raise ConfigValidationError(error_message)

def validate_input_models(self, model_details: dict):
    """Validates whether the input models list is feasible for a multi-model deployment.

    Validation Criteria:
    - Ensures that each base model id resolves to a fetched data science model.
    - Ensures that the base model is not itself a fine tuned model.
    - Ensures that the base model is active.
    - Ensures that each model id provided in `fine_tune_weights` is a fine tuned model.
    - Ensures that each model id provided in `fine_tune_weights` belongs to the same base model.

    Parameters
    ----------
    model_details:
        A dict that contains model id and its corresponding data science model.

    Raises
    ------
    ConfigValidationError:
        - If invalid or unknown base model id is provided.
        - If base model is not active.
        - If model id provided in `fine_tune_weights` is not fine tuned model.
        - If model id provided in `fine_tune_weights` doesn't belong to the same base model.
    """
    for model in self.models:
        base_model_id = model.model_id
        base_model = model_details.get(base_model_id)
        # Guard against ids missing from model_details; without this check
        # the tag lookup below would raise AttributeError on None instead
        # of the documented ConfigValidationError.
        if base_model is None:
            error_message = f"Invalid base model id {base_model_id}. Model details could not be found for this id."
            logger.error(error_message)
            raise ConfigValidationError(error_message)
        if Tags.AQUA_FINE_TUNED_MODEL_TAG in base_model.freeform_tags:
            error_message = f"Invalid base model id {base_model_id}. Specify base model id `model_id` in `models` input, not fine tuned model id."
            logger.error(error_message)
            raise ConfigValidationError(error_message)
        if base_model.lifecycle_state != "ACTIVE":
            error_message = f"Invalid base model id {base_model_id}. Specify active base model id `model_id` in `models` input."
            logger.error(error_message)
            raise ConfigValidationError(error_message)
        if model.fine_tune_weights:
            for lora_module_spec in model.fine_tune_weights:
                fine_tune_model_id = lora_module_spec.model_id
                fine_tune_model = model_details.get(fine_tune_model_id)
                # Same guard as above for fine tuned model ids.
                if fine_tune_model is None:
                    error_message = f"Invalid fine tune model id {fine_tune_model_id} in `models.fine_tune_weights` input. Model details could not be found for this id."
                    logger.error(error_message)
                    raise ConfigValidationError(error_message)
                if (
                    Tags.AQUA_FINE_TUNED_MODEL_TAG
                    not in fine_tune_model.freeform_tags
                ):
                    error_message = f"Invalid fine tune model id {fine_tune_model_id} in `models.fine_tune_weights` input. Fine tune model must have tag {Tags.AQUA_FINE_TUNED_MODEL_TAG}."
                    logger.error(error_message)
                    raise ConfigValidationError(error_message)
                # The FT model records its source base model in custom
                # metadata; fall back to an empty metadata item so `.value`
                # is None (and fails the comparison) when the key is absent.
                fine_tune_base_model_id = fine_tune_model.custom_metadata_list.get(
                    FineTuneCustomMetadata.FINE_TUNE_SOURCE,
                    ModelCustomMetadataItem(
                        key=FineTuneCustomMetadata.FINE_TUNE_SOURCE
                    ),
                ).value
                if fine_tune_base_model_id != base_model_id:
                    error_message = f"Invalid fine tune model id {fine_tune_model_id} in `models.fine_tune_weights` input. Fine tune model must belong to base model {base_model_id}."
                    logger.error(error_message)
                    raise ConfigValidationError(error_message)

class Config:
    # NOTE(review): presumably a pydantic model config — `extra = "allow"`
    # accepts unknown input fields, and `protected_namespaces = ()` disables
    # the default `model_` field-name protection; confirm against the
    # enclosing model's base class.
    extra = "allow"
    protected_namespaces = ()
16 changes: 7 additions & 9 deletions ads/aqua/modeldeployment/model_group_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,14 @@ def dedup_lora_modules(cls, fine_tune_weights: List[LoraModuleSpec]):
for module in fine_tune_weights or []:
name = getattr(module, "model_name", None)
if not name:
logger.warning("Fine-tuned model in AquaMultiModelRef is missing model_name.")
logger.warning(
"Fine-tuned model in AquaMultiModelRef is missing model_name."
)
continue
if name in seen:
logger.warning(f"Duplicate LoRA Module detected: {name!r} (skipping duplicate).")
logger.warning(
f"Duplicate LoRA Module detected: {name!r} (skipping duplicate)."
)
continue
seen.add(name)
unique_modules.append(module)
Expand Down Expand Up @@ -169,14 +173,8 @@ def _merge_gpu_count_params(
model, container_params, container_type_key
)

model_id = (
model.fine_tune_weights[0].model_id
if model.fine_tune_weights
else model.model_id
)

deployment_config = model_config_summary.deployment_config.get(
model_id, AquaDeploymentConfig()
model.model_id, AquaDeploymentConfig()
).configuration.get(
create_deployment_details.instance_shape, ConfigurationItem()
)
Expand Down
18 changes: 11 additions & 7 deletions tests/unitary/with_extras/aqua/test_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -1789,8 +1789,16 @@ def test_create_deployment_for_tei_byoc_embedding_model(
@patch(
"ads.aqua.modeldeployment.entities.CreateModelDeploymentDetails.validate_multimodel_deployment_feasibility"
)
@patch(
"ads.aqua.modeldeployment.entities.CreateModelDeploymentDetails.validate_input_models"
)
@patch(
"ads.aqua.modeldeployment.config_loader.MultiModelDeploymentConfigLoader.fetch_data_science_resources_concurrently"
)
def test_create_deployment_for_multi_model(
self,
mock_fetch_data_science_resources_concurrently,
mock_validate_input_models,
mock_validate_multimodel_deployment_feasibility,
mock_get_deployment_config,
mock_deploy,
Expand Down Expand Up @@ -1902,13 +1910,9 @@ def test_create_deployment_for_multi_model(
predict_log_id="ocid1.log.oc1.<region>.<OCID>",
)

mock_create_multi.assert_called_with(
models=[model_info_1, model_info_2, model_info_3],
compartment_id=TestDataset.USER_COMPARTMENT_ID,
project_id=TestDataset.USER_PROJECT_ID,
freeform_tags=None,
defined_tags=None,
)
mock_create_multi.assert_called()
mock_fetch_data_science_resources_concurrently.assert_called()
mock_validate_input_models.assert_called()
mock_get_container_image.assert_called()
mock_deploy.assert_called()

Expand Down
Loading
Loading