Skip to content

Commit 9a2cde5

Browse files
committed
Add multihost support to mysql rotation lambda
1 parent e3b3d6b commit 9a2cde5

File tree

1 file changed

+90
-19
lines changed

1 file changed

+90
-19
lines changed

SecretsManagerRDSMySQLRotationMultiUser/lambda_function.py

Lines changed: 90 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,15 @@ def set_secret(service_client, arn, token):
180180
if not conn:
181181
logger.error("setSecret: Unable to log into database using current credentials for secret %s" % arn)
182182
raise ValueError("Unable to log into database using current credentials for secret %s" % arn)
183-
conn.close()
183+
# get hostname of existing user
184+
try:
185+
with conn.cursor() as cur:
186+
cur.execute("SELECT CURRENT_USER()")
187+
current_user_fullname = cur.fetchone()[0]
188+
user_hostname = current_user_fullname.split("@")[1]
189+
logger.info("User hostname detected: [%s]", user_hostname)
190+
finally:
191+
conn.close()
184192

185193
# Use the master arn from the current secret to fetch master secret contents
186194
master_arn = current_dict['masterarn']
@@ -200,41 +208,104 @@ def set_secret(service_client, arn, token):
200208
# Now set the password to the pending password
201209
try:
202210
with conn.cursor() as cur:
203-
cur.execute("SELECT User FROM mysql.user WHERE User = %s", pending_dict['username'])
211+
cur.execute(
212+
query="SELECT User FROM mysql.user WHERE User = %s AND Host = %s",
213+
args=(pending_dict["username"], user_hostname),
214+
)
204215
# Create the user if it does not exist
205216
if cur.rowcount == 0:
206-
cur.execute("CREATE USER %s IDENTIFIED BY %s", (pending_dict['username'], pending_dict['password']))
217+
cur.execute(
218+
query="CREATE USER %s@%s IDENTIFIED BY %s",
219+
args=(
220+
pending_dict["username"],
221+
user_hostname,
222+
pending_dict["password"],
223+
),
224+
)
207225

208226
# Copy grants to the new user
209-
cur.execute("SHOW GRANTS FOR %s", current_dict['username'])
227+
cur.execute(
228+
query="SHOW GRANTS FOR %s@%s",
229+
args=(current_dict["username"], user_hostname),
230+
)
210231
for row in cur.fetchall():
211-
grant = row[0].split(' TO ')
212-
new_grant_escaped = grant[0].replace('%', '%%') # % is a special character in Python format strings.
213-
cur.execute(new_grant_escaped + " TO %s", (pending_dict['username'],))
232+
grant = row[0].split(" TO ")
233+
new_grant_escaped = grant[0].replace(
234+
"%", "%%"
235+
) # % is a special character in Python format strings.
236+
cur.execute(
237+
query=new_grant_escaped + " TO %s@%s",
238+
args=(pending_dict["username"], user_hostname),
239+
)
214240

215241
# Get the version of MySQL
216242
cur.execute("SELECT VERSION()")
217243
ver = cur.fetchone()[0]
218244

219245
# Copy TLS options to the new user
220246
escaped_encryption_statement = get_escaped_encryption_statement(ver)
221-
cur.execute("SELECT ssl_type, ssl_cipher, x509_issuer, x509_subject FROM mysql.user WHERE User = %s", current_dict['username'])
247+
cur.execute(
248+
query="SELECT ssl_type, ssl_cipher, x509_issuer, x509_subject FROM mysql.user WHERE User = %s AND Host = %s",
249+
args=(
250+
current_dict["username"],
251+
user_hostname,
252+
),
253+
)
222254
tls_options = cur.fetchone()
223255
ssl_type = tls_options[0]
224256
if not ssl_type:
225-
cur.execute(escaped_encryption_statement + " NONE", pending_dict['username'])
226-
elif ssl_type == "ANY":
227-
cur.execute(escaped_encryption_statement + " SSL", pending_dict['username'])
228-
elif ssl_type == "X509":
229-
cur.execute(escaped_encryption_statement + " X509", pending_dict['username'])
257+
cur.execute(
258+
query=escaped_encryption_statement + " NONE",
259+
args=(
260+
pending_dict["username"],
261+
user_hostname,
262+
),
263+
)
264+
elif "ANY" == ssl_type:
265+
cur.execute(
266+
query=escaped_encryption_statement + " SSL",
267+
args=(
268+
pending_dict["username"],
269+
user_hostname,
270+
),
271+
)
272+
elif "X509" == ssl_type:
273+
cur.execute(
274+
query=escaped_encryption_statement + " X509",
275+
args=(
276+
pending_dict["username"],
277+
user_hostname,
278+
),
279+
)
230280
else:
231-
cur.execute(escaped_encryption_statement + " CIPHER %s AND ISSUER %s AND SUBJECT %s", (pending_dict['username'], tls_options[1], tls_options[2], tls_options[3]))
281+
cur.execute(
282+
query=escaped_encryption_statement
283+
+ " CIPHER %s AND ISSUER %s AND SUBJECT %s",
284+
args=(
285+
pending_dict["username"],
286+
user_hostname,
287+
tls_options[1],
288+
tls_options[2],
289+
tls_options[3],
290+
),
291+
)
232292

233293
# Set the password for the user and commit
234-
password_option = get_password_option(ver)
235-
cur.execute("SET PASSWORD FOR %s = " + password_option, (pending_dict['username'], pending_dict['password']))
294+
password_option = get_password_option(version=ver)
295+
cur.execute(
296+
query="SET PASSWORD FOR %s@%s = " + password_option,
297+
args=(
298+
pending_dict["username"],
299+
user_hostname,
300+
pending_dict["password"],
301+
),
302+
)
236303
conn.commit()
237-
logger.info("setSecret: Successfully set password for %s in MySQL DB for secret arn %s." % (pending_dict['username'], arn))
304+
logger.info(
305+
"setSecret: Successfully set password for %s in MySQL DB for secret arn %s.",
306+
pending_dict["username"],
307+
arn,
308+
)
238309
finally:
239310
conn.close()
240311

@@ -535,9 +606,9 @@ def get_escaped_encryption_statement(version):
535606
536607
"""
537608
if version.startswith("5.6"):
538-
return "GRANT USAGE ON *.* TO %s@'%%' REQUIRE"
609+
return "GRANT USAGE ON *.* TO %s@%s REQUIRE"
539610
else:
540-
return "ALTER USER %s@'%%' REQUIRE"
611+
return "ALTER USER %s@%s REQUIRE"
541612

542613

543614
def is_rds_replica_database(replica_dict, master_dict):

0 commit comments

Comments
 (0)