@@ -180,7 +180,15 @@ def set_secret(service_client, arn, token):
180
180
if not conn :
181
181
logger .error ("setSecret: Unable to log into database using current credentials for secret %s" % arn )
182
182
raise ValueError ("Unable to log into database using current credentials for secret %s" % arn )
183
- conn .close ()
183
+ # get hostname of existing user
184
+ try :
185
+ with conn .cursor () as cur :
186
+ cur .execute ("SELECT CURRENT_USER()" )
187
+ current_user_fullname = cur .fetchone ()[0 ]
188
+ user_hostname = current_user_fullname .split ("@" )[1 ]
189
+ logger .info ("User hostname detected: [%s]" , user_hostname )
190
+ finally :
191
+ conn .close ()
184
192
185
193
# Use the master arn from the current secret to fetch master secret contents
186
194
master_arn = current_dict ['masterarn' ]
@@ -200,41 +208,104 @@ def set_secret(service_client, arn, token):
200
208
# Now set the password to the pending password
201
209
try :
202
210
with conn .cursor () as cur :
203
- cur .execute ("SELECT User FROM mysql.user WHERE User = %s" , pending_dict ['username' ])
211
+ cur .execute (
212
+ query = "SELECT User FROM mysql.user WHERE User = %s AND Host = %s" ,
213
+ args = (pending_dict ["username" ], user_hostname ),
214
+ )
204
215
# Create the user if it does not exist
205
216
if cur .rowcount == 0 :
206
- cur .execute ("CREATE USER %s IDENTIFIED BY %s" , (pending_dict ['username' ], pending_dict ['password' ]))
217
+ cur .execute (
218
+ query = "CREATE USER %s@%s IDENTIFIED BY %s" ,
219
+ args = (
220
+ pending_dict ["username" ],
221
+ user_hostname ,
222
+ pending_dict ["password" ],
223
+ ),
224
+ )
207
225
208
226
# Copy grants to the new user
209
- cur .execute ("SHOW GRANTS FOR %s" , current_dict ['username' ])
227
+ cur .execute (
228
+ query = "SHOW GRANTS FOR %s@%s" ,
229
+ args = (current_dict ["username" ], user_hostname ),
230
+ )
210
231
for row in cur .fetchall ():
211
- grant = row [0 ].split (' TO ' )
212
- new_grant_escaped = grant [0 ].replace ('%' , '%%' ) # % is a special character in Python format strings.
213
- cur .execute (new_grant_escaped + " TO %s" , (pending_dict ['username' ],))
232
+ grant = row [0 ].split (" TO " )
233
+ new_grant_escaped = grant [0 ].replace (
234
+ "%" , "%%"
235
+ ) # % is a special character in Python format strings.
236
+ cur .execute (
237
+ query = new_grant_escaped + " TO %s@%s" ,
238
+ args = (pending_dict ["username" ], user_hostname ),
239
+ )
214
240
215
241
# Get the version of MySQL
216
242
cur .execute ("SELECT VERSION()" )
217
243
ver = cur .fetchone ()[0 ]
218
244
219
245
# Copy TLS options to the new user
220
246
escaped_encryption_statement = get_escaped_encryption_statement (ver )
221
- cur .execute ("SELECT ssl_type, ssl_cipher, x509_issuer, x509_subject FROM mysql.user WHERE User = %s" , current_dict ['username' ])
247
+ cur .execute (
248
+ query = "SELECT ssl_type, ssl_cipher, x509_issuer, x509_subject FROM mysql.user WHERE User = %s AND Host = %s" ,
249
+ args = (
250
+ current_dict ["username" ],
251
+ user_hostname ,
252
+ ),
253
+ )
222
254
tls_options = cur .fetchone ()
223
255
ssl_type = tls_options [0 ]
224
256
if not ssl_type :
225
- cur .execute (escaped_encryption_statement + " NONE" , pending_dict ['username' ])
226
- elif ssl_type == "ANY" :
227
- cur .execute (escaped_encryption_statement + " SSL" , pending_dict ['username' ])
228
- elif ssl_type == "X509" :
229
- cur .execute (escaped_encryption_statement + " X509" , pending_dict ['username' ])
257
+ cur .execute (
258
+ query = escaped_encryption_statement + " NONE" ,
259
+ args = (
260
+ pending_dict ["username" ],
261
+ user_hostname ,
262
+ ),
263
+ )
264
+ elif "ANY" == ssl_type :
265
+ cur .execute (
266
+ query = escaped_encryption_statement + " SSL" ,
267
+ args = (
268
+ pending_dict ["username" ],
269
+ user_hostname ,
270
+ ),
271
+ )
272
+ elif "X509" == ssl_type :
273
+ cur .execute (
274
+ query = escaped_encryption_statement + " X509" ,
275
+ args = (
276
+ pending_dict ["username" ],
277
+ user_hostname ,
278
+ ),
279
+ )
230
280
else :
231
- cur .execute (escaped_encryption_statement + " CIPHER %s AND ISSUER %s AND SUBJECT %s" , (pending_dict ['username' ], tls_options [1 ], tls_options [2 ], tls_options [3 ]))
281
+ cur .execute (
282
+ query = escaped_encryption_statement
283
+ + " CIPHER %s AND ISSUER %s AND SUBJECT %s" ,
284
+ args = (
285
+ pending_dict ["username" ],
286
+ user_hostname ,
287
+ tls_options [1 ],
288
+ tls_options [2 ],
289
+ tls_options [3 ],
290
+ ),
291
+ )
232
292
233
293
# Set the password for the user and commit
234
- password_option = get_password_option (ver )
235
- cur .execute ("SET PASSWORD FOR %s = " + password_option , (pending_dict ['username' ], pending_dict ['password' ]))
294
+ password_option = get_password_option (version = ver )
295
+ cur .execute (
296
+ query = "SET PASSWORD FOR %s@%s = " + password_option ,
297
+ args = (
298
+ pending_dict ["username" ],
299
+ user_hostname ,
300
+ pending_dict ["password" ],
301
+ ),
302
+ )
236
303
conn .commit ()
237
- logger .info ("setSecret: Successfully set password for %s in MySQL DB for secret arn %s." % (pending_dict ['username' ], arn ))
304
+ logger .info (
305
+ "setSecret: Successfully set password for %s in MySQL DB for secret arn %s." ,
306
+ pending_dict ["username" ],
307
+ arn ,
308
+ )
238
309
finally :
239
310
conn .close ()
240
311
@@ -535,9 +606,9 @@ def get_escaped_encryption_statement(version):
535
606
536
607
"""
537
608
if version .startswith ("5.6" ):
538
- return "GRANT USAGE ON *.* TO %s@'%%' REQUIRE"
609
+ return "GRANT USAGE ON *.* TO %s@%s REQUIRE"
539
610
else :
540
- return "ALTER USER %s@'%%' REQUIRE"
611
+ return "ALTER USER %s@%s REQUIRE"
541
612
542
613
543
614
def is_rds_replica_database (replica_dict , master_dict ):
0 commit comments