Also throw on invalid WebSocket usage

This commit is contained in:
Alex Hultman 2019-01-18 16:57:54 +01:00
parent 90df6bbe3a
commit 1d4b874ae7
2 changed files with 51 additions and 40 deletions

View File

@ -3,7 +3,7 @@
#include "Utilities.h" #include "Utilities.h"
using namespace v8; using namespace v8;
/* uWS.App.ws('/pattern', options) */ /* uWS.App.ws('/pattern', behavior) */
template <typename APP> template <typename APP>
void uWS_App_ws(const FunctionCallbackInfo<Value> &args) { void uWS_App_ws(const FunctionCallbackInfo<Value> &args) {
APP *app = (APP *) args.Holder()->GetAlignedPointerFromInternalField(0); APP *app = (APP *) args.Holder()->GetAlignedPointerFromInternalField(0);
@ -14,7 +14,8 @@ void uWS_App_ws(const FunctionCallbackInfo<Value> &args) {
return; return;
} }
// todo: small leak here, should be unique_ptrs moved in /* We don't need to care for these just yet, since we do
* not have a way to free the app itself */
Persistent<Function> *openPf = new Persistent<Function>(); Persistent<Function> *openPf = new Persistent<Function>();
Persistent<Function> *messagePf = new Persistent<Function>(); Persistent<Function> *messagePf = new Persistent<Function>();
Persistent<Function> *drainPf = new Persistent<Function>(); Persistent<Function> *drainPf = new Persistent<Function>();
@ -78,9 +79,8 @@ void uWS_App_ws(const FunctionCallbackInfo<Value> &args) {
PerSocketData *perSocketData = (PerSocketData *) ws->getUserData(); PerSocketData *perSocketData = (PerSocketData *) ws->getUserData();
Local<Value> argv[3] = {Local<Object>::New(isolate, *(perSocketData->socketPf)), Local<Value> argv[3] = {Local<Object>::New(isolate, *(perSocketData->socketPf)),
/*ArrayBuffer::New(isolate, (void *) message.data(), message.length())*/ messageArrayBuffer, messageArrayBuffer,
Boolean::New(isolate, opCode == uWS::OpCode::BINARY) Boolean::New(isolate, opCode == uWS::OpCode::BINARY)};
};
Local<Function>::New(isolate, *messagePf)->Call(isolate->GetCurrentContext()->Global(), 3, argv); Local<Function>::New(isolate, *messagePf)->Call(isolate->GetCurrentContext()->Global(), 3, argv);
/* Important: we clear the ArrayBuffer to make sure it is not invalidly used after return */ /* Important: we clear the ArrayBuffer to make sure it is not invalidly used after return */
@ -108,14 +108,18 @@ void uWS_App_ws(const FunctionCallbackInfo<Value> &args) {
HandleScope hs(isolate); HandleScope hs(isolate);
Local<ArrayBuffer> messageArrayBuffer = ArrayBuffer::New(isolate, (void *) message.data(), message.length()); Local<ArrayBuffer> messageArrayBuffer = ArrayBuffer::New(isolate, (void *) message.data(), message.length());
PerSocketData *perSocketData = (PerSocketData *) ws->getUserData(); PerSocketData *perSocketData = (PerSocketData *) ws->getUserData();
Local<Value> argv[3] = {Local<Object>::New(isolate, *(perSocketData->socketPf)), Local<Object> wsObject = Local<Object>::New(isolate, *(perSocketData->socketPf));
Local<Value> argv[3] = {wsObject,
Integer::New(isolate, code), Integer::New(isolate, code),
messageArrayBuffer messageArrayBuffer};
};
/* Invalidate this wsObject */
wsObject->SetAlignedPointerInInternalField(0, nullptr);
Local<Function>::New(isolate, *closePf)->Call(isolate->GetCurrentContext()->Global(), 3, argv); Local<Function>::New(isolate, *closePf)->Call(isolate->GetCurrentContext()->Global(), 3, argv);
delete perSocketData->socketPf;
/* Again, here we clear the buffer to avoid strange bugs */ /* Again, here we clear the buffer to avoid strange bugs */
messageArrayBuffer->Neuter(); messageArrayBuffer->Neuter();
}; };
@ -176,13 +180,6 @@ void uWS_App_listen(const FunctionCallbackInfo<Value> &args) {
args.GetReturnValue().Set(args.Holder()); args.GetReturnValue().Set(args.Holder());
} }
/* This is very risky */
template <typename APP>
void uWS_App_free(const FunctionCallbackInfo<Value> &args) {
APP *app = (APP *) args.Holder()->GetAlignedPointerFromInternalField(0);
delete app;
}
template <typename APP> template <typename APP>
void uWS_App(const FunctionCallbackInfo<Value> &args) { void uWS_App(const FunctionCallbackInfo<Value> &args) {
Local<FunctionTemplate> appTemplate = FunctionTemplate::New(isolate); Local<FunctionTemplate> appTemplate = FunctionTemplate::New(isolate);
@ -294,7 +291,6 @@ void uWS_App(const FunctionCallbackInfo<Value> &args) {
/* ws, listen */ /* ws, listen */
appTemplate->PrototypeTemplate()->Set(String::NewFromUtf8(isolate, "ws"), FunctionTemplate::New(isolate, uWS_App_ws<APP>)); appTemplate->PrototypeTemplate()->Set(String::NewFromUtf8(isolate, "ws"), FunctionTemplate::New(isolate, uWS_App_ws<APP>));
appTemplate->PrototypeTemplate()->Set(String::NewFromUtf8(isolate, "listen"), FunctionTemplate::New(isolate, uWS_App_listen<APP>)); appTemplate->PrototypeTemplate()->Set(String::NewFromUtf8(isolate, "listen"), FunctionTemplate::New(isolate, uWS_App_listen<APP>));
// appTemplate->PrototypeTemplate()->Set(String::NewFromUtf8(isolate, "free"), FunctionTemplate::New(isolate, uWS_App_free<APP>));
Local<Object> localApp = appTemplate->GetFunction()->NewInstance(isolate->GetCurrentContext()).ToLocalChecked(); Local<Object> localApp = appTemplate->GetFunction()->NewInstance(isolate->GetCurrentContext()).ToLocalChecked();
localApp->SetAlignedPointerInInternalField(0, app); localApp->SetAlignedPointerInInternalField(0, app);

View File

@ -3,20 +3,29 @@
#include "Utilities.h" #include "Utilities.h"
using namespace v8; using namespace v8;
// todo: also check for use after free here! /* todo: probably isCorked, cork should be exposed? */
// todo: probably isCorked, cork should be exposed?
struct WebSocketWrapper { struct WebSocketWrapper {
static Persistent<Object> wsTemplate[2]; static Persistent<Object> wsTemplate[2];
template <bool SSL> template <bool SSL>
static inline uWS::WebSocket<SSL, true> *getWebSocket(const FunctionCallbackInfo<Value> &args) { static inline uWS::WebSocket<SSL, true> *getWebSocket(const FunctionCallbackInfo<Value> &args) {
return ((uWS::WebSocket<SSL, true> *) args.Holder()->GetAlignedPointerFromInternalField(0)); auto *ws = (uWS::WebSocket<SSL, true> *) args.Holder()->GetAlignedPointerFromInternalField(0);
if (!ws) {
args.GetReturnValue().Set(isolate->ThrowException(String::NewFromUtf8(isolate, "Invalid access of closed uWS.WebSocket/SSLWebSocket.")));
}
return ws;
}
static inline void invalidateWsObject(const FunctionCallbackInfo<Value> &args) {
args.Holder()->SetAlignedPointerInInternalField(0, nullptr);
} }
/* Takes code, message, returns undefined */ /* Takes code, message, returns undefined */
template <bool SSL> template <bool SSL>
static void uWS_WebSocket_close(const FunctionCallbackInfo<Value> &args) { static void uWS_WebSocket_close(const FunctionCallbackInfo<Value> &args) {
auto *ws = getWebSocket<SSL>(args);
if (ws) {
int code = 0; int code = 0;
if (args.Length() >= 1) { if (args.Length() >= 1) {
code = args[0]->Uint32Value(); code = args[0]->Uint32Value();
@ -27,30 +36,36 @@ struct WebSocketWrapper {
return; return;
} }
getWebSocket<SSL>(args)->close(code, message.getString()); invalidateWsObject(args);
ws->close(code, message.getString());
}
} }
/* Takes nothing, returns integer */ /* Takes nothing, returns integer */
template <bool SSL> template <bool SSL>
static void uWS_WebSocket_getBufferedAmount(const FunctionCallbackInfo<Value> &args) { static void uWS_WebSocket_getBufferedAmount(const FunctionCallbackInfo<Value> &args) {
int bufferedAmount = getWebSocket<SSL>(args)->getBufferedAmount(); auto *ws = getWebSocket<SSL>(args);
if (ws) {
int bufferedAmount = ws->getBufferedAmount();
args.GetReturnValue().Set(Integer::New(isolate, bufferedAmount)); args.GetReturnValue().Set(Integer::New(isolate, bufferedAmount));
} }
}
/* Takes message, isBinary. Returns true on success, false otherwise */ /* Takes message, isBinary. Returns true on success, false otherwise */
template <bool SSL> template <bool SSL>
static void uWS_WebSocket_send(const FunctionCallbackInfo<Value> &args) { static void uWS_WebSocket_send(const FunctionCallbackInfo<Value> &args) {
auto *ws = getWebSocket<SSL>(args);
if (ws) {
NativeString message(args.GetIsolate(), args[0]); NativeString message(args.GetIsolate(), args[0]);
if (message.isInvalid(args)) { if (message.isInvalid(args)) {
return; return;
} }
bool isBinary = args[1]->BooleanValue(); bool ok = ws->send(message.getString(), args[1]->BooleanValue() ? uWS::OpCode::BINARY : uWS::OpCode::TEXT);
bool ok = getWebSocket<SSL>(args)->send(message.getString(), isBinary ? uWS::OpCode::BINARY : uWS::OpCode::TEXT);
args.GetReturnValue().Set(Boolean::New(isolate, ok)); args.GetReturnValue().Set(Boolean::New(isolate, ok));
} }
}
template <bool SSL> template <bool SSL>
static void initWsTemplate() { static void initWsTemplate() {