@@ -124,6 +124,15 @@ def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
124
124
125
125
# Get All preferred provider types of the workspace
126
126
provider_name_to_preferred_model_provider_records_dict = self ._get_all_preferred_model_providers (tenant_id )
127
+ # Ensure that both the original provider name and its ModelProviderID string representation
128
+ # are present in the dictionary to handle cases where either form might be used
129
+ for provider_name in list (provider_name_to_preferred_model_provider_records_dict .keys ()):
130
+ provider_id = ModelProviderID (provider_name )
131
+ if str (provider_id ) not in provider_name_to_preferred_model_provider_records_dict :
132
+ # Add the ModelProviderID string representation if it's not already present
133
+ provider_name_to_preferred_model_provider_records_dict [str (provider_id )] = (
134
+ provider_name_to_preferred_model_provider_records_dict [provider_name ]
135
+ )
127
136
128
137
# Get All provider model settings
129
138
provider_name_to_provider_model_settings_dict = self ._get_all_provider_model_settings (tenant_id )
@@ -497,8 +506,8 @@ def _get_all_provider_load_balancing_configs(tenant_id: str) -> dict[str, list[L
497
506
498
507
@staticmethod
499
508
def _init_trial_provider_records (
500
- tenant_id : str , provider_name_to_provider_records_dict : dict [str , list ]
501
- ) -> dict [str , list ]:
509
+ tenant_id : str , provider_name_to_provider_records_dict : dict [str , list [ Provider ] ]
510
+ ) -> dict [str , list [ Provider ] ]:
502
511
"""
503
512
Initialize trial provider records if not exists.
504
513
@@ -532,7 +541,7 @@ def _init_trial_provider_records(
532
541
if ProviderQuotaType .TRIAL not in provider_quota_to_provider_record_dict :
533
542
try :
534
543
# FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic
535
- provider_record = Provider (
544
+ new_provider_record = Provider (
536
545
tenant_id = tenant_id ,
537
546
# TODO: Use provider name with prefix after the data migration.
538
547
provider_name = ModelProviderID (provider_name ).provider_name ,
@@ -542,11 +551,12 @@ def _init_trial_provider_records(
542
551
quota_used = 0 ,
543
552
is_valid = True ,
544
553
)
545
- db .session .add (provider_record )
554
+ db .session .add (new_provider_record )
546
555
db .session .commit ()
556
+ provider_name_to_provider_records_dict [provider_name ].append (new_provider_record )
547
557
except IntegrityError :
548
558
db .session .rollback ()
549
- provider_record = (
559
+ existed_provider_record = (
550
560
db .session .query (Provider )
551
561
.filter (
552
562
Provider .tenant_id == tenant_id ,
@@ -556,11 +566,14 @@ def _init_trial_provider_records(
556
566
)
557
567
.first ()
558
568
)
559
- if provider_record and not provider_record .is_valid :
560
- provider_record .is_valid = True
569
+ if not existed_provider_record :
570
+ continue
571
+
572
+ if not existed_provider_record .is_valid :
573
+ existed_provider_record .is_valid = True
561
574
db .session .commit ()
562
575
563
- provider_name_to_provider_records_dict [provider_name ].append (provider_record )
576
+ provider_name_to_provider_records_dict [provider_name ].append (existed_provider_record )
564
577
565
578
return provider_name_to_provider_records_dict
566
579
0 commit comments