From 93b9938e0231a6ad7be179546196191425f94db8 Mon Sep 17 00:00:00 2001 From: Jenny Tam Date: Fri, 13 Apr 2018 15:06:10 -0700 Subject: [PATCH] Replace most strlen with strnlen_s --- source/pdo_sqlsrv/pdo_dbh.cpp | 4 ++-- source/pdo_sqlsrv/pdo_util.cpp | 8 ++++---- source/shared/FormattedPrint.cpp | 2 +- source/shared/StringFunctions.cpp | 6 ++++++ source/shared/StringFunctions.h | 2 ++ source/shared/core_conn.cpp | 12 ++++++------ source/shared/core_sqlsrv.h | 7 +++---- source/shared/localizationimpl.cpp | 6 +++--- 8 files changed, 27 insertions(+), 20 deletions(-) diff --git a/source/pdo_sqlsrv/pdo_dbh.cpp b/source/pdo_sqlsrv/pdo_dbh.cpp index 2a143629..b1208d09 100644 --- a/source/pdo_sqlsrv/pdo_dbh.cpp +++ b/source/pdo_sqlsrv/pdo_dbh.cpp @@ -1253,7 +1253,7 @@ char * pdo_sqlsrv_dbh_last_id( _Inout_ pdo_dbh_t *dbh, _In_z_ const char *name, else { char* quoted_table = NULL; size_t quoted_len = 0; - int quoted = pdo_sqlsrv_dbh_quote( dbh, name, strlen( name ), "ed_table, "ed_len, PDO_PARAM_NULL TSRMLS_CC ); + int quoted = pdo_sqlsrv_dbh_quote( dbh, name, strnlen_s( name ), "ed_table, "ed_len, PDO_PARAM_NULL TSRMLS_CC ); SQLSRV_ASSERT( quoted, "PDO::lastInsertId failed to quote the table name."); snprintf( last_insert_id_query, LAST_INSERT_ID_QUERY_MAX_LEN, SEQUENCE_CURRENT_VALUE_QUERY, quoted_table ); sqlsrv_free( quoted_table ); @@ -1270,7 +1270,7 @@ char * pdo_sqlsrv_dbh_last_id( _Inout_ pdo_dbh_t *dbh, _In_z_ const char *name, sqlsrv_malloc_auto_ptr wsql_string; unsigned int wsql_len; - wsql_string = utf16_string_from_mbcs_string( SQLSRV_ENCODING_CHAR, reinterpret_cast( last_insert_id_query ), static_cast( strlen( last_insert_id_query )), &wsql_len ); + wsql_string = utf16_string_from_mbcs_string( SQLSRV_ENCODING_CHAR, reinterpret_cast( last_insert_id_query ), static_cast( strnlen_s( last_insert_id_query )), &wsql_len ); CHECK_CUSTOM_ERROR( wsql_string == 0, driver_stmt, SQLSRV_ERROR_QUERY_STRING_ENCODING_TRANSLATE, get_last_error_message() ) { throw core::CoreException(); diff --git a/source/pdo_sqlsrv/pdo_util.cpp b/source/pdo_sqlsrv/pdo_util.cpp index 695b20e6..6103650e 100644 --- a/source/pdo_sqlsrv/pdo_util.cpp +++ b/source/pdo_sqlsrv/pdo_util.cpp @@ -473,7 +473,7 @@ bool pdo_sqlsrv_handle_dbh_error( _Inout_ sqlsrv_context& ctx, _In_opt_ unsigned SQLSRV_ASSERT( err == true, "No ODBC error was found" ); } - SQLSRV_ASSERT(strlen(reinterpret_cast(error->sqlstate)) <= sizeof(dbh->error_code), "Error code overflow"); + SQLSRV_ASSERT(strnlen_s(reinterpret_cast(error->sqlstate)) <= sizeof(dbh->error_code), "Error code overflow"); strcpy_s(dbh->error_code, sizeof(dbh->error_code), reinterpret_cast(error->sqlstate)); switch( dbh->error_mode ) { @@ -486,7 +486,7 @@ bool pdo_sqlsrv_handle_dbh_error( _Inout_ sqlsrv_context& ctx, _In_opt_ unsigned break; case PDO_ERRMODE_WARNING: if( !warning ) { - size_t msg_len = strlen( reinterpret_cast( error->native_message )) + SQL_SQLSTATE_BUFSIZE + size_t msg_len = strnlen_s( reinterpret_cast( error->native_message )) + SQL_SQLSTATE_BUFSIZE + MAX_DIGITS + WARNING_MIN_LENGTH + 1; sqlsrv_malloc_auto_ptr msg; msg = static_cast( sqlsrv_malloc( msg_len ) ); @@ -525,7 +525,7 @@ bool pdo_sqlsrv_handle_stmt_error( _Inout_ sqlsrv_context& ctx, _In_opt_ unsigne SQLSRV_ASSERT( err == true, "No ODBC error was found" ); } - SQLSRV_ASSERT( strlen( reinterpret_cast( error->sqlstate ) ) <= sizeof( pdo_stmt->error_code ), "Error code overflow"); + SQLSRV_ASSERT( strnlen_s( reinterpret_cast( error->sqlstate ) ) <= sizeof( pdo_stmt->error_code ), "Error code overflow"); strcpy_s( pdo_stmt->error_code, sizeof( pdo_stmt->error_code ), reinterpret_cast( error->sqlstate )); switch( pdo_stmt->dbh->error_mode ) { @@ -612,7 +612,7 @@ void pdo_sqlsrv_throw_exception( _In_ sqlsrv_error_const* error TSRMLS_DC ) SQLSRV_ASSERT( zr != FAILURE, "Failed to initialize exception object" ); sqlsrv_malloc_auto_ptr ex_msg; - size_t ex_msg_len = strlen( reinterpret_cast( error->native_message )) + SQL_SQLSTATE_BUFSIZE + + size_t ex_msg_len = strnlen_s( reinterpret_cast( error->native_message )) + SQL_SQLSTATE_BUFSIZE + 12 + 1; // 12 = "SQLSTATE[]: " ex_msg = reinterpret_cast( sqlsrv_malloc( ex_msg_len )); snprintf( ex_msg, ex_msg_len, EXCEPTION_MSG_TEMPLATE, error->sqlstate, error->native_message ); diff --git a/source/shared/FormattedPrint.cpp b/source/shared/FormattedPrint.cpp index 0f6bf36b..6a19a425 100644 --- a/source/shared/FormattedPrint.cpp +++ b/source/shared/FormattedPrint.cpp @@ -709,7 +709,7 @@ int FormattedPrintA( IFormattedPrintOutput * output, const char *format, v ++text.sz; } - textlen = (int)strlen(text.sz); /* compute length of text */ + textlen = (int)strnlen_s(text.sz); /* compute length of text */ } break; diff --git a/source/shared/StringFunctions.cpp b/source/shared/StringFunctions.cpp index 2d8a549a..550c0d35 100644 --- a/source/shared/StringFunctions.cpp +++ b/source/shared/StringFunctions.cpp @@ -138,6 +138,12 @@ int mplat_strcat_s( char * dest, size_t destSize, const char * src ) } return 0; } + +size_t strnlen_s(const char * _Str, size_t _MaxCount) +{ + return (_Str==0) ? 0 : strnlen(_Str, _MaxCount); +} + // // End copy functions //---------------------------------------------------------------------------- diff --git a/source/shared/StringFunctions.h b/source/shared/StringFunctions.h index 0ab97662..a0b31b78 100644 --- a/source/shared/StringFunctions.h +++ b/source/shared/StringFunctions.h @@ -31,6 +31,8 @@ int mplat_memcpy_s(void *_S1, size_t _N1, const void *_S2, size_t _N); int mplat_strcat_s( char *strDestination, size_t numberOfElements, const char *strSource ); int mplat_strcpy_s(char * _Dst, size_t _SizeInBytes, const char * _Src); +size_t strnlen_s(const char * _Str, size_t _MaxCount = INT_MAX); + // Copy #define memcpy_s mplat_memcpy_s #define strcat_s mplat_strcat_s diff --git a/source/shared/core_conn.cpp b/source/shared/core_conn.cpp index f72d4c36..03efa77a 100644 --- a/source/shared/core_conn.cpp +++ b/source/shared/core_conn.cpp @@ -755,31 +755,31 @@ void build_connection_string_and_set_conn_attr( _Inout_ sqlsrv_conn* conn, _Inou try { // Add the server name - common_conn_str_append_func( ODBCConnOptions::SERVER, server, strlen( server ), connection_string TSRMLS_CC ); + common_conn_str_append_func( ODBCConnOptions::SERVER, server, strnlen_s( server ), connection_string TSRMLS_CC ); // if uid is not present then we use trusted connection. - if(uid == NULL || strlen( uid ) == 0 ) { + if(uid == NULL || strnlen_s( uid ) == 0 ) { connection_string += "Trusted_Connection={Yes};"; } else { - bool escaped = core_is_conn_opt_value_escaped( uid, strlen( uid )); + bool escaped = core_is_conn_opt_value_escaped( uid, strnlen_s( uid )); CHECK_CUSTOM_ERROR( !escaped, conn, SQLSRV_ERROR_UID_PWD_BRACES_NOT_ESCAPED ) { throw core::CoreException(); } - common_conn_str_append_func( ODBCConnOptions::UID, uid, strlen( uid ), connection_string TSRMLS_CC ); + common_conn_str_append_func( ODBCConnOptions::UID, uid, strnlen_s( uid ), connection_string TSRMLS_CC ); // if no password was given, then don't add a password to the connection string. Perhaps the UID // given doesn't have a password? if( pwd != NULL ) { - escaped = core_is_conn_opt_value_escaped( pwd, strlen( pwd )); + escaped = core_is_conn_opt_value_escaped( pwd, strnlen_s( pwd )); CHECK_CUSTOM_ERROR( !escaped, conn, SQLSRV_ERROR_UID_PWD_BRACES_NOT_ESCAPED ) { throw core::CoreException(); } - common_conn_str_append_func( ODBCConnOptions::PWD, pwd, strlen( pwd ), connection_string TSRMLS_CC ); + common_conn_str_append_func( ODBCConnOptions::PWD, pwd, strnlen_s( pwd ), connection_string TSRMLS_CC ); } } diff --git a/source/shared/core_sqlsrv.h b/source/shared/core_sqlsrv.h index 382b7ace..e5cf4565 100644 --- a/source/shared/core_sqlsrv.h +++ b/source/shared/core_sqlsrv.h @@ -56,10 +56,9 @@ // #define MultiByteToWideChar SystemLocale::ToUtf16 - - #define stricmp strcasecmp #define strnicmp strncasecmp +#define strnlen_s(s) strnlen_s(s, INT_MAX) #ifndef _WIN32 #define GetLastError() errno @@ -998,7 +997,7 @@ struct sqlsrv_encoding { bool not_for_connection; sqlsrv_encoding( _In_ const char* iana, _In_ unsigned int code_page, _In_ bool not_for_conn = false ): - iana( iana ), iana_len( strlen( iana )), code_page( code_page ), not_for_connection( not_for_conn ) + iana( iana ), iana_len( strnlen_s( iana )), code_page( code_page ), not_for_connection( not_for_conn ) { } }; @@ -1784,7 +1783,7 @@ inline bool call_error_handler( _Inout_ sqlsrv_context* ctx, _In_ unsigned long inline bool is_truncated_warning( _In_ SQLCHAR* state ) { #if defined(ZEND_DEBUG) - if( state == NULL || strlen( reinterpret_cast( state )) != 5 ) { \ + if( state == NULL || strnlen_s( reinterpret_cast( state )) != 5 ) { \ DIE( "Incorrect SQLSTATE given to is_truncated_warning." ); \ } #endif diff --git a/source/shared/localizationimpl.cpp b/source/shared/localizationimpl.cpp index eebef4fd..92221e98 100644 --- a/source/shared/localizationimpl.cpp +++ b/source/shared/localizationimpl.cpp @@ -310,7 +310,7 @@ SystemLocale::SystemLocale( const char * localeName ) charsetName = charsetName ? charsetName + 1 : localeName; for (const LocaleCP& lcp : lcpTable) { - if (!strncasecmp(lcp.localeName, charsetName, strlen(lcp.localeName))) + if (!strncasecmp(lcp.localeName, charsetName, strnlen_s(lcp.localeName))) { m_uAnsiCP = lcp.codePage; return; @@ -346,7 +346,7 @@ size_t SystemLocale::ToUtf16( UINT srcCodePage, const char * src, SSIZE_T cchSrc *pErrorCode = ERROR_INVALID_PARAMETER; return 0; } - size_t cchSrcActual = (cchSrc < 0 ? (1+strlen(src)) : cchSrc); + size_t cchSrcActual = (cchSrc < 0 ? (1+strnlen_s(src)) : cchSrc); bool hasLoss; return cvt.Convert( dest, cchDest, src, cchSrcActual, false, &hasLoss, pErrorCode ); } @@ -361,7 +361,7 @@ size_t SystemLocale::ToUtf16Strict( UINT srcCodePage, const char * src, SSIZE_T *pErrorCode = ERROR_INVALID_PARAMETER; return 0; } - size_t cchSrcActual = (cchSrc < 0 ? (1+strlen(src)) : cchSrc); + size_t cchSrcActual = (cchSrc < 0 ? (1+strnlen_s(src)) : cchSrc); bool hasLoss; return cvt.Convert( dest, cchDest, src, cchSrcActual, true, &hasLoss, pErrorCode ); }