diff --git a/winpr/include/winpr/stream.h b/winpr/include/winpr/stream.h index f351eaa15..ed637f034 100644 --- a/winpr/include/winpr/stream.h +++ b/winpr/include/winpr/stream.h @@ -27,6 +27,7 @@ #include #include #include +#include #ifdef __cplusplus extern "C" @@ -56,6 +57,19 @@ extern "C" WINPR_API void Stream_StaticInit(wStream* s, BYTE* buffer, size_t size); WINPR_API void Stream_Free(wStream* s, BOOL bFreeBuffer); +#define Stream_CheckAndLogRequiredLength(tag, s, len) \ + Stream_CheckAndLogRequiredLengthEx(tag, WLOG_WARN, s, len, "%s(%s:%d)", __FUNCTION__, \ + __FILE__, __LINE__) + WINPR_API BOOL Stream_CheckAndLogRequiredLengthEx(const char* tag, DWORD level, wStream* s, + UINT64 len, const char* fmt, ...); + WINPR_API BOOL Stream_CheckAndLogRequiredLengthExVa(const char* tag, DWORD level, wStream* s, + UINT64 len, const char* fmt, va_list args); + WINPR_API BOOL Stream_CheckAndLogRequiredLengthWLogEx(wLog* log, DWORD level, wStream* s, + UINT64 len, const char* fmt, ...); + WINPR_API BOOL Stream_CheckAndLogRequiredLengthWLogExVa(wLog* log, DWORD level, wStream* s, + UINT64 len, const char* fmt, + va_list args); + static INLINE void Stream_Seek(wStream* s, size_t _offset) { s->pointer += (_offset); diff --git a/winpr/libwinpr/utils/stream.c b/winpr/libwinpr/utils/stream.c index 1271981b7..cc119c771 100644 --- a/winpr/libwinpr/utils/stream.c +++ b/winpr/libwinpr/utils/stream.c @@ -132,3 +132,68 @@ void Stream_Free(wStream* s, BOOL bFreeBuffer) free(s); } } + +BOOL Stream_CheckAndLogRequiredLengthEx(const char* tag, DWORD level, wStream* s, UINT64 len, + const char* fmt, ...) +{ + const size_t actual = Stream_GetRemainingLength(s); + + if (actual < len) + { + va_list args; + + va_start(args, fmt); + Stream_CheckAndLogRequiredLengthExVa(tag, level, s, len, fmt, args); + va_end(args); + + return FALSE; + } + return TRUE; +} + +BOOL Stream_CheckAndLogRequiredLengthExVa(const char* tag, DWORD level, wStream* s, UINT64 len, + const char* fmt, va_list args) +{ + const size_t actual = Stream_GetRemainingLength(s); + + if (actual < len) + return Stream_CheckAndLogRequiredLengthWLogExVa(WLog_Get(tag), level, s, len, fmt, args); + return TRUE; +} + +BOOL Stream_CheckAndLogRequiredLengthWLogEx(wLog* log, DWORD level, wStream* s, UINT64 len, + const char* fmt, ...) +{ + const size_t actual = Stream_GetRemainingLength(s); + + if (actual < len) + { + va_list args; + + va_start(args, fmt); + Stream_CheckAndLogRequiredLengthWLogExVa(log, level, s, len, fmt, args); + va_end(args); + + return FALSE; + } + return TRUE; +} + +BOOL Stream_CheckAndLogRequiredLengthWLogExVa(wLog* log, DWORD level, wStream* s, UINT64 len, + const char* fmt, va_list args) +{ + const size_t actual = Stream_GetRemainingLength(s); + + if (actual < len) + { + char prefix[1024] = { 0 }; + + vsnprintf(prefix, sizeof(prefix), fmt, args); + + WLog_Print(log, level, "[%s] invalid length, got %" PRIuz ", require at least %" PRIu64, + prefix, actual, len); + winpr_log_backtrace_ex(log, level, 20); + return FALSE; + } + return TRUE; +}