Skip to content

Commit 22a1bc3

Browse files
authored
fix: perferred model provider not match with provider. (#18282)
Signed-off-by: -LAN- <[email protected]>
1 parent caa179a commit 22a1bc3

File tree

1 file changed

+21
-8
lines changed

1 file changed

+21
-8
lines changed

api/core/provider_manager.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,15 @@ def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
124124

125125
# Get All preferred provider types of the workspace
126126
provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id)
127+
# Ensure that both the original provider name and its ModelProviderID string representation
128+
# are present in the dictionary to handle cases where either form might be used
129+
for provider_name in list(provider_name_to_preferred_model_provider_records_dict.keys()):
130+
provider_id = ModelProviderID(provider_name)
131+
if str(provider_id) not in provider_name_to_preferred_model_provider_records_dict:
132+
# Add the ModelProviderID string representation if it's not already present
133+
provider_name_to_preferred_model_provider_records_dict[str(provider_id)] = (
134+
provider_name_to_preferred_model_provider_records_dict[provider_name]
135+
)
127136

128137
# Get All provider model settings
129138
provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id)
@@ -497,8 +506,8 @@ def _get_all_provider_load_balancing_configs(tenant_id: str) -> dict[str, list[L
497506

498507
@staticmethod
499508
def _init_trial_provider_records(
500-
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list]
501-
) -> dict[str, list]:
509+
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]]
510+
) -> dict[str, list[Provider]]:
502511
"""
503512
Initialize trial provider records if not exists.
504513
@@ -532,7 +541,7 @@ def _init_trial_provider_records(
532541
if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
533542
try:
534543
# FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic
535-
provider_record = Provider(
544+
new_provider_record = Provider(
536545
tenant_id=tenant_id,
537546
# TODO: Use provider name with prefix after the data migration.
538547
provider_name=ModelProviderID(provider_name).provider_name,
@@ -542,11 +551,12 @@ def _init_trial_provider_records(
542551
quota_used=0,
543552
is_valid=True,
544553
)
545-
db.session.add(provider_record)
554+
db.session.add(new_provider_record)
546555
db.session.commit()
556+
provider_name_to_provider_records_dict[provider_name].append(new_provider_record)
547557
except IntegrityError:
548558
db.session.rollback()
549-
provider_record = (
559+
existed_provider_record = (
550560
db.session.query(Provider)
551561
.filter(
552562
Provider.tenant_id == tenant_id,
@@ -556,11 +566,14 @@ def _init_trial_provider_records(
556566
)
557567
.first()
558568
)
559-
if provider_record and not provider_record.is_valid:
560-
provider_record.is_valid = True
569+
if not existed_provider_record:
570+
continue
571+
572+
if not existed_provider_record.is_valid:
573+
existed_provider_record.is_valid = True
561574
db.session.commit()
562575

563-
provider_name_to_provider_records_dict[provider_name].append(provider_record)
576+
provider_name_to_provider_records_dict[provider_name].append(existed_provider_record)
564577

565578
return provider_name_to_provider_records_dict
566579

0 commit comments

Comments
 (0)