diff --git a/source/shared/core_conn.cpp b/source/shared/core_conn.cpp index da6284dc..a1fe06b0 100644 --- a/source/shared/core_conn.cpp +++ b/source/shared/core_conn.cpp @@ -257,7 +257,9 @@ sqlsrv_conn* core_sqlsrv_connect( _In_ sqlsrv_context& henv_cp, _In_ sqlsrv_cont throw core::CoreException(); } - load_azure_key_vault( conn ); + // After load_azure_key_vault, reset AKV related variables regardless + load_azure_key_vault(conn); + conn->ce_option.akv_reset(); // determine the version of the server we're connected to. The server version is left in the // connection upon return. @@ -292,6 +294,7 @@ sqlsrv_conn* core_sqlsrv_connect( _In_ sqlsrv_context& henv_cp, _In_ sqlsrv_cont throw; } catch( core::CoreException& ) { + conn->ce_option.akv_reset(); conn_str.clear(); conn->invalidate(); throw; @@ -862,6 +865,7 @@ void build_connection_string_and_set_conn_attr( _Inout_ sqlsrv_conn* conn, _Inou } catch( core::CoreException& ) { + conn->ce_option.akv_reset(); throw; } } @@ -984,10 +988,10 @@ void load_azure_key_vault(_Inout_ sqlsrv_conn* conn TSRMLS_DC) throw core::CoreException(); } - char *akv_id = Z_STRVAL_P(conn->ce_option.akv_id); - char *akv_secret = Z_STRVAL_P(conn->ce_option.akv_secret); - unsigned int id_len = static_cast(Z_STRLEN_P(conn->ce_option.akv_id)); - unsigned int key_size = static_cast(Z_STRLEN_P(conn->ce_option.akv_secret)); + char *akv_id = conn->ce_option.akv_id.get(); + char *akv_secret = conn->ce_option.akv_secret.get(); + unsigned int id_len = strnlen_s(akv_id); + unsigned int key_size = strnlen_s(akv_secret); configure_azure_key_vault(conn, AKV_CONFIG_FLAGS, conn->ce_option.akv_mode, 0); configure_azure_key_vault(conn, AKV_CONFIG_PRINCIPALID, akv_id, id_len); @@ -1120,6 +1124,7 @@ void ce_akv_str_set_func::func(_In_ connection_option const* option, _In_ zval* { SQLSRV_ASSERT(Z_TYPE_P(value) == IS_STRING, "Azure Key Vault keywords accept only strings."); + const char *value_str = Z_STRVAL_P(value); size_t value_len = Z_STRLEN_P(value); CHECK_CUSTOM_ERROR(value_len <= 0, conn, SQLSRV_ERROR_KEYSTORE_INVALID_VALUE) { @@ -1130,7 +1135,6 @@ void ce_akv_str_set_func::func(_In_ connection_option const* option, _In_ zval* { case SQLSRV_CONN_OPTION_KEYSTORE_AUTHENTICATION: { - char *value_str = Z_STRVAL_P(value); if (!stricmp(value_str, "KeyVaultPassword")) { conn->ce_option.akv_mode = AKVCFG_AUTHMODE_PASSWORD; } else if (!stricmp(value_str, "KeyVaultClientSecret")) { @@ -1145,14 +1149,19 @@ void ce_akv_str_set_func::func(_In_ connection_option const* option, _In_ zval* break; } case SQLSRV_CONN_OPTION_KEYSTORE_PRINCIPAL_ID: - { - conn->ce_option.akv_id = value; - conn->ce_option.akv_required = true; - break; - } case SQLSRV_CONN_OPTION_KEYSTORE_SECRET: { - conn->ce_option.akv_secret = value; + // Create a new string to save a copy of the zvalue + char *pValue = static_cast(sqlsrv_malloc(value_len + 1)); + memcpy_s(pValue, value_len + 1, value_str, value_len); + pValue[value_len] = '\0'; // this makes sure there will be no trailing garbage + + // This will free the existing memory block before assigning the new pointer -- the user might set the value(s) more than once + if (option->conn_option_key == SQLSRV_CONN_OPTION_KEYSTORE_PRINCIPAL_ID) { + conn->ce_option.akv_id = pValue; + } else { + conn->ce_option.akv_secret = pValue; + } conn->ce_option.akv_required = true; break; } diff --git a/source/shared/core_sqlsrv.h b/source/shared/core_sqlsrv.h index 6c49464c..e6efb3a8 100644 --- a/source/shared/core_sqlsrv.h +++ b/source/shared/core_sqlsrv.h @@ -1055,15 +1055,23 @@ struct stmt_option; // This holds the various details of column encryption. struct col_encryption_option { - bool enabled; // column encryption enabled, false by default - SQLINTEGER akv_mode; - zval_auto_ptr akv_id; - zval_auto_ptr akv_secret; - bool akv_required; + bool enabled; // column encryption enabled, false by default + SQLINTEGER akv_mode; + sqlsrv_malloc_auto_ptr akv_id; + sqlsrv_malloc_auto_ptr akv_secret; + bool akv_required; col_encryption_option() : enabled( false ), akv_mode(-1), akv_required( false ) { } + + void akv_reset() + { + akv_id.reset(); + akv_secret.reset(); + akv_required = false; + akv_mode = -1; + } }; // *** connection resource structure *** diff --git a/test/functional/pdo_sqlsrv/pdo_ae_azure_key_vault_keywords.phpt b/test/functional/pdo_sqlsrv/pdo_ae_azure_key_vault_keywords.phpt index 514bab24..5f0c4068 100644 --- a/test/functional/pdo_sqlsrv/pdo_ae_azure_key_vault_keywords.phpt +++ b/test/functional/pdo_sqlsrv/pdo_ae_azure_key_vault_keywords.phpt @@ -1,7 +1,7 @@ --TEST-- Test connection keywords for Azure Key Vault for Always Encrypted. --SKIPIF-- - + --FILE-- + --FILE--