-
Notifications
You must be signed in to change notification settings - Fork 49
[AQUA][MMD] Enhance FT weights. #1206
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,7 +16,7 @@ | |
|
||
from ads.aqua import logger | ||
from ads.aqua.app import AquaApp | ||
from ads.aqua.common.entities import AquaMultiModelRef, LoraModuleSpec | ||
from ads.aqua.common.entities import AquaMultiModelRef | ||
from ads.aqua.common.enums import ( | ||
ConfigFolder, | ||
CustomInferenceContainerTypeFamily, | ||
|
@@ -84,7 +84,6 @@ | |
) | ||
from ads.aqua.model.enums import MultiModelSupportedTaskType | ||
from ads.aqua.model.utils import ( | ||
extract_base_model_from_ft, | ||
extract_fine_tune_artifacts_path, | ||
) | ||
from ads.common.auth import default_signer | ||
|
@@ -238,6 +237,7 @@ def create( | |
def create_multi( | ||
self, | ||
models: List[AquaMultiModelRef], | ||
model_details: Dict[str, DataScienceModel], | ||
project_id: Optional[str] = None, | ||
compartment_id: Optional[str] = None, | ||
freeform_tags: Optional[Dict] = None, | ||
|
@@ -251,6 +251,8 @@ def create_multi( | |
---------- | ||
models : List[AquaMultiModelRef] | ||
List of AquaMultiModelRef instances for creating a multi-model group. | ||
model_details : | ||
A dict that contains model id and its corresponding data science model. | ||
project_id : Optional[str] | ||
The project ID for the multi-model group. | ||
compartment_id : Optional[str] | ||
|
@@ -298,7 +300,7 @@ def create_multi( | |
# Process each model in the input list | ||
for model in models: | ||
# Retrieve model metadata from the Model Catalog using the model ID | ||
source_model = DataScienceModel.from_id(model.model_id) | ||
source_model: DataScienceModel = model_details.get(model.model_id) | ||
display_name = source_model.display_name | ||
model_file_description = source_model.model_file_description | ||
# If model_name is not explicitly provided, use the model's display name | ||
|
@@ -310,42 +312,21 @@ def create_multi( | |
"Please register the model first." | ||
) | ||
|
||
# Check if the model is a fine-tuned model based on its tags | ||
is_fine_tuned_model = ( | ||
Tags.AQUA_FINE_TUNED_MODEL_TAG in source_model.freeform_tags | ||
) | ||
|
||
base_model_artifact_path = "" | ||
fine_tune_path = "" | ||
|
||
if is_fine_tuned_model: | ||
# Extract artifact paths for the base and fine-tuned model | ||
base_model_artifact_path, fine_tune_path = ( | ||
extract_fine_tune_artifacts_path(source_model) | ||
) | ||
|
||
# Create a single LoRA module specification for the fine-tuned model | ||
# TODO: Support multiple LoRA modules in the future | ||
model.fine_tune_weights = [ | ||
LoraModuleSpec( | ||
model_id=model.model_id, | ||
model_name=model.model_name, | ||
model_path=fine_tune_path, | ||
if model.fine_tune_weights: | ||
for loral_module_spec in model.fine_tune_weights: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. NIT: should this be `lora_module_spec`? (not sure if "loral" is intentional) |
||
fine_tune_model: DataScienceModel = model_details.get( | ||
loral_module_spec.model_id | ||
) | ||
] | ||
|
||
# Use the LoRA module name as the model's display name | ||
display_name = model.model_name | ||
|
||
# Temporarily override model ID and name with those of the base model | ||
# TODO: Revisit this logic once proper base/FT model handling is implemented | ||
model.model_id, model.model_name = extract_base_model_from_ft( | ||
source_model | ||
) | ||
# Extract artifact paths for the base and fine-tuned model | ||
base_model_artifact_path, fine_tune_path = ( | ||
extract_fine_tune_artifacts_path(fine_tune_model) | ||
) | ||
loral_module_spec.model_path = fine_tune_path | ||
else: | ||
# For base models, use the original artifact path | ||
base_model_artifact_path = source_model.artifact | ||
display_name = model.model_name | ||
|
||
if not base_model_artifact_path: | ||
# Fail if no artifact is found for the base model model | ||
|
@@ -356,7 +337,7 @@ def create_multi( | |
|
||
# Update the artifact path in the model configuration | ||
model.artifact_location = base_model_artifact_path | ||
display_name_list.append(display_name) | ||
display_name_list.append(model.model_name) | ||
|
||
# Extract model task metadata from source model | ||
self._extract_model_task(model, source_model) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,12 +13,14 @@ | |
from ads.aqua.config.utils.serializer import Serializable | ||
from ads.aqua.constants import UNKNOWN_DICT | ||
from ads.aqua.data import AquaResourceIdentifier | ||
from ads.aqua.finetuning.constants import FineTuneCustomMetadata | ||
from ads.aqua.modeldeployment.config_loader import ( | ||
ConfigurationItem, | ||
ModelDeploymentConfigSummary, | ||
) | ||
from ads.common.serializer import DataClassSerializable | ||
from ads.common.utils import UNKNOWN, get_console_link | ||
from ads.model.model_metadata import ModelCustomMetadataItem | ||
|
||
|
||
class ConfigValidationError(Exception): | ||
|
@@ -474,6 +476,61 @@ def validate_multimodel_deployment_feasibility( | |
logger.error(error_message) | ||
raise ConfigValidationError(error_message) | ||
|
||
def validate_input_models(self, model_details: dict): | ||
"""Validates whether the input models list is feasible for a multi-model deployment. | ||
|
||
Validation Criteria: | ||
- Ensures that the base model is provided. | ||
- Ensures that the base model is active. | ||
- Ensures that the model id provided in `fine_tune_weights` is fine tuned model. | ||
- Ensures that the model id provided in `fine_tune_weights` belongs to the same base model. | ||
|
||
Parameters | ||
---------- | ||
model_details: | ||
A dict that contains model id and its corresponding data science model. | ||
|
||
Raises | ||
------ | ||
ConfigValidationError: | ||
- If invalid base model id is provided. | ||
- If base model is not active. | ||
- If model id provided in `fine_tune_weights` is not fine tuned model. | ||
- If model id provided in `fine_tune_weights` doesn't belong to the same base model. | ||
""" | ||
for model in self.models: | ||
base_model_id = model.model_id | ||
base_model = model_details.get(base_model_id) | ||
if Tags.AQUA_FINE_TUNED_MODEL_TAG in base_model.freeform_tags: | ||
error_message = f"Invalid base model id {base_model_id}. Specify base model id `model_id` in `models` input, not fine tuned model id." | ||
logger.error(error_message) | ||
raise ConfigValidationError(error_message) | ||
if base_model.lifecycle_state != "ACTIVE": | ||
error_message = f"Invalid base model id {base_model_id}. Specify active base model id `model_id` in `models` input." | ||
logger.error(error_message) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Including both base_model_id and base_model_name would be more descriptive in the error message, e.g.: "Invalid model_id specified in models input: {base_model_name} with OCID {base_model_ocid} is not a base model." |
||
raise ConfigValidationError(error_message) | ||
if model.fine_tune_weights: | ||
for loral_module_spec in model.fine_tune_weights: | ||
fine_tune_model_id = loral_module_spec.model_id | ||
fine_tune_model = model_details.get(fine_tune_model_id) | ||
if ( | ||
Tags.AQUA_FINE_TUNED_MODEL_TAG | ||
not in fine_tune_model.freeform_tags | ||
): | ||
error_message = f"Invalid fine tune model id {fine_tune_model_id} in `models.fine_tune_weights` input. Fine tune model must have tag {Tags.AQUA_FINE_TUNED_MODEL_TAG}." | ||
logger.error(error_message) | ||
raise ConfigValidationError(error_message) | ||
fine_tune_base_model_id = fine_tune_model.custom_metadata_list.get( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe use extract_base_model_from_ft() here, but this is also OK. |
||
FineTuneCustomMetadata.FINE_TUNE_SOURCE, | ||
ModelCustomMetadataItem( | ||
key=FineTuneCustomMetadata.FINE_TUNE_SOURCE | ||
), | ||
).value | ||
if fine_tune_base_model_id != base_model_id: | ||
error_message = f"Invalid fine tune model id {fine_tune_model_id} in `models.fine_tune_weights` input. Fine tune model must belong to base model {base_model_id}." | ||
logger.error(error_message) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. NIT: it may be easier for the user if the base model name were shown in error_message along with base_model_id, e.g.: "The specified fine-tuned model '{fine_tune_model_name}' (OCID: {fine_tune_model_id}) does not belong to the required base model '{base_model_name}' (OCID: {base_model_id})." |
||
raise ConfigValidationError(error_message) | ||
|
||
class Config: | ||
extra = "allow" | ||
protected_namespaces = () |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this for base models or just FT models? The description is not clear.