diff --git a/src/AppWrapper.h b/src/AppWrapper.h index 66d8ba1..3402c51 100644 --- a/src/AppWrapper.h +++ b/src/AppWrapper.h @@ -1,5 +1,3 @@ -// test so that we pass Autobahn with compression/without compression with SSL/without SSL - #include "App.h" #include #include "Utilities.h" @@ -11,8 +9,10 @@ void uWS_App_ws(const FunctionCallbackInfo &args) { APP *app = (APP *) args.Holder()->GetAlignedPointerFromInternalField(0); typename APP::WebSocketBehavior behavior = {}; - // pattern - NativeString nativeString(args.GetIsolate(), args[0]); + NativeString pattern(args.GetIsolate(), args[0]); + if (pattern.isInvalid(args)) { + return; + } // todo: small leak here, should be unique_ptrs moved in Persistent *openPf = new Persistent(); @@ -120,7 +120,7 @@ void uWS_App_ws(const FunctionCallbackInfo &args) { messageArrayBuffer->Neuter(); }; - app->template ws(std::string(nativeString.getData(), nativeString.getLength()), std::move(behavior)); + app->template ws(std::string(pattern.getString()), std::move(behavior)); /* Return this */ args.GetReturnValue().Set(args.Holder()); @@ -131,13 +131,16 @@ template void uWS_App_get(F f, const FunctionCallbackInfo &args) { APP *app = (APP *) args.Holder()->GetAlignedPointerFromInternalField(0); - NativeString nativeString(args.GetIsolate(), args[0]); + NativeString pattern(args.GetIsolate(), args[0]); + if (pattern.isInvalid(args)) { + return; + } - // todo: make it UniquePersistent - Persistent *pf = new Persistent(); + /* todo: make it UniquePersistent */ + std::shared_ptr> pf(new Persistent); pf->Reset(args.GetIsolate(), Local::Cast(args[1])); - (app->*f)(std::string(nativeString.getData(), nativeString.getLength()), [pf](auto *res, auto *req) { + (app->*f)(std::string(pattern.getString()), [pf](auto *res, auto *req) { HandleScope hs(isolate); Local resObject = HttpResponseWrapper::getResInstance(); @@ -148,6 +151,12 @@ void uWS_App_get(F f, const FunctionCallbackInfo &args) { Local argv[] = {resObject, reqObject}; Local::New(isolate, *pf)->Call(isolate->GetCurrentContext()->Global(), 2, argv); + + /* Properly invalidate req */ + reqObject->SetAlignedPointerInInternalField(0, nullptr); + + /* µWS itself will terminate if not responded and not attached + * onAborted handler, so we can assume it's done */ }); args.GetReturnValue().Set(args.Holder()); @@ -164,10 +173,16 @@ void uWS_App_listen(const FunctionCallbackInfo &args) { Local::Cast(args[1])->Call(isolate->GetCurrentContext()->Global(), 1, argv); }); - // Return this args.GetReturnValue().Set(args.Holder()); } +/* This is very risky */ +template +void uWS_App_free(const FunctionCallbackInfo &args) { + APP *app = (APP *) args.Holder()->GetAlignedPointerFromInternalField(0); + delete app; +} + template void uWS_App(const FunctionCallbackInfo &args) { Local appTemplate = FunctionTemplate::New(isolate); @@ -190,29 +205,41 @@ void uWS_App(const FunctionCallbackInfo &args) { if (args.Length() == 1) { /* Key file name */ NativeString keyFileNameValue(isolate, Local::Cast(args[0])->Get(String::NewFromUtf8(isolate, "key_file_name"))); - if (keyFileNameValue.getLength()) { - keyFileName.append(keyFileNameValue.getData(), keyFileNameValue.getLength()); + if (keyFileNameValue.isInvalid(args)) { + return; + } + if (keyFileNameValue.getString().length()) { + keyFileName.append(keyFileNameValue.getString()); ssl_options.key_file_name = keyFileName.c_str(); } /* Cert file name */ NativeString certFileNameValue(isolate, Local::Cast(args[0])->Get(String::NewFromUtf8(isolate, "cert_file_name"))); - if (certFileNameValue.getLength()) { - certFileName.append(certFileNameValue.getData(), certFileNameValue.getLength()); + if (certFileNameValue.isInvalid(args)) { + return; + } + if (certFileNameValue.getString().length()) { + certFileName.append(certFileNameValue.getString()); ssl_options.cert_file_name = certFileName.c_str(); } /* Passphrase */ NativeString passphraseValue(isolate, Local::Cast(args[0])->Get(String::NewFromUtf8(isolate, "passphrase"))); - if (passphraseValue.getLength()) { - passphrase.append(passphraseValue.getData(), passphraseValue.getLength()); + if (passphraseValue.isInvalid(args)) { + return; + } + if (passphraseValue.getString().length()) { + passphrase.append(passphraseValue.getString()); ssl_options.passphrase = passphrase.c_str(); } /* DH params file name */ NativeString dhParamsFileNameValue(isolate, Local::Cast(args[0])->Get(String::NewFromUtf8(isolate, "dh_params_file_name"))); - if (dhParamsFileNameValue.getLength()) { - dhParamsFileName.append(dhParamsFileNameValue.getData(), dhParamsFileNameValue.getLength()); + if (dhParamsFileNameValue.isInvalid(args)) { + return; + } + if (dhParamsFileNameValue.getString().length()) { + dhParamsFileName.append(dhParamsFileNameValue.getString()); ssl_options.dh_params_file_name = dhParamsFileName.c_str(); } } diff --git a/src/HttpRequestWrapper.h b/src/HttpRequestWrapper.h index ab3121c..6cb364f 100644 --- a/src/HttpRequestWrapper.h +++ b/src/HttpRequestWrapper.h @@ -3,57 +3,72 @@ #include "Utilities.h" using namespace v8; -// du behver inte klona dessa -// det finns bara en enda giltig request vid någon tid, och det är alltid -// inom en callback - -// håll en färdig request och tillåt functioner endast när du är inom callbacken - /* This one is the same for SSL and non-SSL */ struct HttpRequestWrapper { static Persistent reqTemplate; - // todo: refuse all function calls if we are not inside correct callback - static inline uWS::HttpRequest *getHttpRequest(const FunctionCallbackInfo &args) { - return ((uWS::HttpRequest *) args.Holder()->GetAlignedPointerFromInternalField(0)); + /* Thow on deleted request */ + auto *req = (uWS::HttpRequest *) args.Holder()->GetAlignedPointerFromInternalField(0); + if (!req) { + args.GetReturnValue().Set(isolate->ThrowException(String::NewFromUtf8(isolate, "Using uWS.HttpRequest past its request handler return is forbidden (it is stack allocated)."))); + } + return req; } /* Takes int, returns string (must be in bounds) */ static void req_getParameter(const FunctionCallbackInfo &args) { - int index = args[0]->Uint32Value(); - std::string_view parameter = getHttpRequest(args)->getParameter(index); + auto *req = getHttpRequest(args); + if (req) { + int index = args[0]->Uint32Value(); + std::string_view parameter = req->getParameter(index); - args.GetReturnValue().Set(String::NewFromUtf8(isolate, parameter.data(), v8::String::kNormalString, parameter.length())); + args.GetReturnValue().Set(String::NewFromUtf8(isolate, parameter.data(), v8::String::kNormalString, parameter.length())); + } } /* Takes nothing, returns string */ static void req_getUrl(const FunctionCallbackInfo &args) { - std::string_view url = getHttpRequest(args)->getUrl(); + auto *req = getHttpRequest(args); + if (req) { + std::string_view url = req->getUrl(); - args.GetReturnValue().Set(String::NewFromUtf8(isolate, url.data(), v8::String::kNormalString, url.length())); + args.GetReturnValue().Set(String::NewFromUtf8(isolate, url.data(), v8::String::kNormalString, url.length())); + } } /* Takes String, returns String */ static void req_getHeader(const FunctionCallbackInfo &args) { - NativeString data(args.GetIsolate(), args[0]); - char *buf = data.getData(); int length = data.getLength(); + auto *req = getHttpRequest(args); + if (req) { + NativeString data(args.GetIsolate(), args[0]); + if (data.isInvalid(args)) { + return; + } - std::string_view header = getHttpRequest(args)->getHeader(std::string_view(buf, length)); + std::string_view header = req->getHeader(data.getString()); - args.GetReturnValue().Set(String::NewFromUtf8(isolate, header.data(), v8::String::kNormalString, header.length())); + args.GetReturnValue().Set(String::NewFromUtf8(isolate, header.data(), v8::String::kNormalString, header.length())); + } } + /* Takes nothing, returns string */ static void req_getMethod(const FunctionCallbackInfo &args) { - std::string_view method = getHttpRequest(args)->getMethod(); + auto *req = getHttpRequest(args); + if (req) { + std::string_view method = req->getMethod(); - args.GetReturnValue().Set(String::NewFromUtf8(isolate, method.data(), v8::String::kNormalString, method.length())); + args.GetReturnValue().Set(String::NewFromUtf8(isolate, method.data(), v8::String::kNormalString, method.length())); + } } static void req_getQuery(const FunctionCallbackInfo &args) { - std::string_view query = getHttpRequest(args)->getQuery(); + auto *req = getHttpRequest(args); + if (req) { + std::string_view query = req->getQuery(); - args.GetReturnValue().Set(String::NewFromUtf8(isolate, query.data(), v8::String::kNormalString, query.length())); + args.GetReturnValue().Set(String::NewFromUtf8(isolate, query.data(), v8::String::kNormalString, query.length())); + } } static void initReqTemplate() { @@ -62,7 +77,6 @@ struct HttpRequestWrapper { reqTemplateLocal->SetClassName(String::NewFromUtf8(isolate, "uWS.HttpRequest")); reqTemplateLocal->InstanceTemplate()->SetInternalFieldCount(1); - /* Register our functions */ reqTemplateLocal->PrototypeTemplate()->Set(String::NewFromUtf8(isolate, "getHeader"), FunctionTemplate::New(isolate, req_getHeader)); reqTemplateLocal->PrototypeTemplate()->Set(String::NewFromUtf8(isolate, "getParameter"), FunctionTemplate::New(isolate, req_getParameter)); @@ -76,8 +90,6 @@ struct HttpRequestWrapper { } static Local getReqInstance() { - // if we attach a number that counts up to this req we can check if the number is still valid when calling functions? - return Local::New(isolate, reqTemplate)->Clone(); } }; diff --git a/src/HttpResponseWrapper.h b/src/HttpResponseWrapper.h index 99b4383..32e8545 100644 --- a/src/HttpResponseWrapper.h +++ b/src/HttpResponseWrapper.h @@ -8,125 +8,194 @@ struct HttpResponseWrapper { template static inline uWS::HttpResponse *getHttpResponse(const FunctionCallbackInfo &args) { - return (uWS::HttpResponse *) args.Holder()->GetAlignedPointerFromInternalField(0); + auto *res = (uWS::HttpResponse *) args.Holder()->GetAlignedPointerFromInternalField(0); + if (!res) { + args.GetReturnValue().Set(isolate->ThrowException(String::NewFromUtf8(isolate, "Invalid access of discarded (invalid, deleted) uWS.HttpResponse/SSLHttpResponse."))); + } + return res; + } + + /* Marks this JS object invalid */ + static inline void invalidateResObject(const FunctionCallbackInfo &args) { + args.Holder()->SetAlignedPointerInInternalField(0, nullptr); } /* Takes nothing, kills the connection */ template static void res_close(const FunctionCallbackInfo &args) { - getHttpResponse(args)->close(); - - args.GetReturnValue().Set(args.Holder()); + auto *res = getHttpResponse(args); + if (res) { + invalidateResObject(args); + res->close(); + args.GetReturnValue().Set(args.Holder()); + } } /* Takes function of data and isLast. Expects nothing from callback, returns this */ template static void res_onData(const FunctionCallbackInfo &args) { - /* This thing perfectly fits in with unique_function, and will Reset on destructor */ - UniquePersistent p(isolate, Local::Cast(args[0])); + auto *res = getHttpResponse(args); + if (res) { + /* This thing perfectly fits in with unique_function, and will Reset on destructor */ + UniquePersistent p(isolate, Local::Cast(args[0])); - getHttpResponse(args)->onData([p = std::move(p)](std::string_view data, bool last) { - HandleScope hs(isolate); + res->onData([p = std::move(p)](std::string_view data, bool last) { + HandleScope hs(isolate); - Local dataArrayBuffer = ArrayBuffer::New(isolate, (void *) data.data(), data.length()); + Local dataArrayBuffer = ArrayBuffer::New(isolate, (void *) data.data(), data.length()); - Local argv[] = {dataArrayBuffer, Boolean::New(isolate, last)}; - Local::New(isolate, p)->Call(isolate->GetCurrentContext()->Global(), 2, argv); + Local argv[] = {dataArrayBuffer, Boolean::New(isolate, last)}; + Local::New(isolate, p)->Call(isolate->GetCurrentContext()->Global(), 2, argv); - dataArrayBuffer->Neuter(); - }); + dataArrayBuffer->Neuter(); + }); - args.GetReturnValue().Set(args.Holder()); + args.GetReturnValue().Set(args.Holder()); + } } /* Takes nothing, returns nothing. Cb wants nothing returned. */ template static void res_onAborted(const FunctionCallbackInfo &args) { - /* This thing perfectly fits in with unique_function, and will Reset on destructor */ - UniquePersistent p(isolate, Local::Cast(args[0])); + auto *res = getHttpResponse(args); + if (res) { + /* This thing perfectly fits in with unique_function, and will Reset on destructor */ + UniquePersistent p(isolate, Local::Cast(args[0])); - getHttpResponse(args)->onAborted([p = std::move(p)]() { - HandleScope hs(isolate); + /* This is how we capture res (C++ this in invocation of this function) */ + UniquePersistent resObject(isolate, args.Holder()); - Local::New(isolate, p)->Call(isolate->GetCurrentContext()->Global(), 0, nullptr); - }); + res->onAborted([p = std::move(p), resObject = std::move(resObject)]() { + HandleScope hs(isolate); - args.GetReturnValue().Set(args.Holder()); + /* Mark this resObject invalid */ + Local::New(isolate, resObject)->SetAlignedPointerInInternalField(0, nullptr); + + Local::New(isolate, p)->Call(isolate->GetCurrentContext()->Global(), 0, nullptr); + }); + + args.GetReturnValue().Set(args.Holder()); + } } /* Returns the current write offset */ template static void res_getWriteOffset(const FunctionCallbackInfo &args) { - args.GetReturnValue().Set(Integer::New(isolate, getHttpResponse(args)->getWriteOffset())); + auto *res = getHttpResponse(args); + if (res) { + args.GetReturnValue().Set(Integer::New(isolate, getHttpResponse(args)->getWriteOffset())); + } } /* Takes function of bool(int), returns this */ template static void res_onWritable(const FunctionCallbackInfo &args) { - /* This thing perfectly fits in with unique_function, and will Reset on destructor */ - UniquePersistent p(isolate, Local::Cast(args[0])); + auto *res = getHttpResponse(args); + if (res) { + /* This thing perfectly fits in with unique_function, and will Reset on destructor */ + UniquePersistent p(isolate, Local::Cast(args[0])); - getHttpResponse(args)->onWritable([p = std::move(p)](int offset) { - HandleScope hs(isolate); + res->onWritable([p = std::move(p)](int offset) { + HandleScope hs(isolate); - Local argv[] = {Integer::NewFromUnsigned(isolate, offset)}; - return Local::New(isolate, p)->Call(isolate->GetCurrentContext()->Global(), 1, argv)->BooleanValue(); - }); + Local argv[] = {Integer::NewFromUnsigned(isolate, offset)}; + return Local::New(isolate, p)->Call(isolate->GetCurrentContext()->Global(), 1, argv)->BooleanValue(); + /* How important is this return? */ + }); - args.GetReturnValue().Set(args.Holder()); + args.GetReturnValue().Set(args.Holder()); + } } /* Takes string or arraybuffer, returns this */ template static void res_writeStatus(const FunctionCallbackInfo &args) { - NativeString data(args.GetIsolate(), args[0]); - getHttpResponse(args)->writeStatus(std::string_view(data.getData(), data.getLength())); + auto *res = getHttpResponse(args); + if (res) { + NativeString data(args.GetIsolate(), args[0]); + if (data.isInvalid(args)) { + return; + } + res->writeStatus(data.getString()); - args.GetReturnValue().Set(args.Holder()); + args.GetReturnValue().Set(args.Holder()); + } } /* Takes string or arraybuffer, returns this */ template static void res_end(const FunctionCallbackInfo &args) { - NativeString data(args.GetIsolate(), args[0]); - getHttpResponse(args)->end(std::string_view(data.getData(), data.getLength())); + auto *res = getHttpResponse(args); + if (res) { + NativeString data(args.GetIsolate(), args[0]); + if (data.isInvalid(args)) { + return; + } + invalidateResObject(args); + res->end(data.getString()); - args.GetReturnValue().Set(args.Holder()); + args.GetReturnValue().Set(args.Holder()); + } } /* Takes data and optionally totalLength, returns true for success, false for backpressure */ template static void res_tryEnd(const FunctionCallbackInfo &args) { - NativeString data(args.GetIsolate(), args[0]); + auto *res = getHttpResponse(args); + if (res) { + NativeString data(args.GetIsolate(), args[0]); + if (data.isInvalid(args)) { + return; + } - int totalSize = 0; - if (args.Length() > 1) { - totalSize = args[1]->Uint32Value(); + int totalSize = 0; + if (args.Length() > 1) { + totalSize = args[1]->Uint32Value(); + } + + bool ok = res->tryEnd(data.getString(), totalSize); + + /* Invalidate this object if we responded completely */ + if (res->hasResponded()) { + invalidateResObject(args); + } + + args.GetReturnValue().Set(Boolean::New(isolate, ok)); } - - bool ok = getHttpResponse(args)->tryEnd(std::string_view(data.getData(), data.getLength()), totalSize); - - args.GetReturnValue().Set(Boolean::New(isolate, ok)); } /* Takes data, returns true for success, false for backpressure */ template static void res_write(const FunctionCallbackInfo &args) { - NativeString data(args.GetIsolate(), args[0]); - bool ok = getHttpResponse(args)->write(std::string_view(data.getData(), data.getLength())); + auto *res = getHttpResponse(args); + if (res) { + NativeString data(args.GetIsolate(), args[0]); + if (data.isInvalid(args)) { + return; + } + bool ok = res->write(data.getString()); - args.GetReturnValue().Set(Boolean::New(isolate, ok)); + args.GetReturnValue().Set(Boolean::New(isolate, ok)); + } } /* Takes key, value. Returns this */ template static void res_writeHeader(const FunctionCallbackInfo &args) { - NativeString header(args.GetIsolate(), args[0]); - NativeString value(args.GetIsolate(), args[1]); - getHttpResponse(args)->writeHeader(std::string_view(header.getData(), header.getLength()), - std::string_view(value.getData(), value.getLength())); + auto *res = getHttpResponse(args); + if (res) { + NativeString header(args.GetIsolate(), args[0]); + if (header.isInvalid(args)) { + return; + } + NativeString value(args.GetIsolate(), args[1]); + if (value.isInvalid(args)) { + return; + } + res->writeHeader(header.getString(),value.getString()); - args.GetReturnValue().Set(args.Holder()); + args.GetReturnValue().Set(args.Holder()); + } } template diff --git a/src/Utilities.h b/src/Utilities.h index bd971f8..6de5e62 100644 --- a/src/Utilities.h +++ b/src/Utilities.h @@ -1,11 +1,15 @@ #ifndef ADDON_UTILITIES_H #define ADDON_UTILITIES_H +#include +using namespace v8; + class NativeString { char *data; size_t length; char utf8ValueMemory[sizeof(String::Utf8Value)]; String::Utf8Value *utf8Value = nullptr; + bool invalid = false; public: NativeString(Isolate *isolate, const Local &value) { if (value->IsUndefined()) { @@ -26,14 +30,21 @@ public: length = contents.ByteLength(); data = (char *) contents.Data(); } else { - static char empty[] = ""; - data = empty; - length = 0; + invalid = true; } } - char *getData() {return data;} - size_t getLength() {return length;} + bool isInvalid(const FunctionCallbackInfo &args) { + if (invalid) { + args.GetReturnValue().Set(isolate->ThrowException(String::NewFromUtf8(isolate, "Text and data can only be passed by String, ArrayBuffer or TypedArray."))); + } + return invalid; + } + + std::string_view getString() { + return {data, length}; + } + ~NativeString() { if (utf8Value) { utf8Value->~Utf8Value(); diff --git a/src/WebSocketWrapper.h b/src/WebSocketWrapper.h index 837f0cc..4918237 100644 --- a/src/WebSocketWrapper.h +++ b/src/WebSocketWrapper.h @@ -3,6 +3,7 @@ #include "Utilities.h" using namespace v8; +// todo: also check for use after free here! // todo: probably isCorked, cork should be exposed? struct WebSocketWrapper { @@ -17,18 +18,16 @@ struct WebSocketWrapper { template static void uWS_WebSocket_close(const FunctionCallbackInfo &args) { int code = 0; - std::string_view message; - if (args.Length() >= 1) { code = args[0]->Uint32Value(); } - if (args.Length() >= 2) { - NativeString nativeString(args.GetIsolate(), args[1]); - message = {nativeString.getData(), nativeString.getLength()}; + NativeString message(args.GetIsolate(), args[1]); + if (message.isInvalid(args)) { + return; } - getWebSocket(args)->close(code, message); + getWebSocket(args)->close(code, message.getString()); } /* Takes nothing, returns integer */ @@ -41,13 +40,14 @@ struct WebSocketWrapper { /* Takes message, isBinary. Returns true on success, false otherwise */ template static void uWS_WebSocket_send(const FunctionCallbackInfo &args) { - NativeString nativeString(args.GetIsolate(), args[0]); + NativeString message(args.GetIsolate(), args[0]); + if (message.isInvalid(args)) { + return; + } bool isBinary = args[1]->BooleanValue(); - bool ok = getWebSocket(args)->send( - std::string_view(nativeString.getData(), nativeString.getLength()), isBinary ? uWS::OpCode::BINARY : uWS::OpCode::TEXT - ); + bool ok = getWebSocket(args)->send(message.getString(), isBinary ? uWS::OpCode::BINARY : uWS::OpCode::TEXT); args.GetReturnValue().Set(Boolean::New(isolate, ok)); } diff --git a/src/addon.cpp b/src/addon.cpp index b273f08..b3b9374 100644 --- a/src/addon.cpp +++ b/src/addon.cpp @@ -23,8 +23,6 @@ #include using namespace v8; -void emptyNextTickQueue(Isolate *isolate); - /* These two are definitely static */ std::vector> nextTickQueue; Isolate *isolate;