Throw exceptions on ill-use instead of segfaulting

This commit is contained in:
Alex Hultman 2019-01-18 16:33:56 +01:00
parent be8aaeeba7
commit d067388df4
6 changed files with 229 additions and 112 deletions

View File

@ -1,5 +1,3 @@
// test so that we pass Autobahn with compression/without compression with SSL/without SSL
#include "App.h" #include "App.h"
#include <v8.h> #include <v8.h>
#include "Utilities.h" #include "Utilities.h"
@ -11,8 +9,10 @@ void uWS_App_ws(const FunctionCallbackInfo<Value> &args) {
APP *app = (APP *) args.Holder()->GetAlignedPointerFromInternalField(0); APP *app = (APP *) args.Holder()->GetAlignedPointerFromInternalField(0);
typename APP::WebSocketBehavior behavior = {}; typename APP::WebSocketBehavior behavior = {};
// pattern NativeString pattern(args.GetIsolate(), args[0]);
NativeString nativeString(args.GetIsolate(), args[0]); if (pattern.isInvalid(args)) {
return;
}
// todo: small leak here, should be unique_ptrs moved in // todo: small leak here, should be unique_ptrs moved in
Persistent<Function> *openPf = new Persistent<Function>(); Persistent<Function> *openPf = new Persistent<Function>();
@ -120,7 +120,7 @@ void uWS_App_ws(const FunctionCallbackInfo<Value> &args) {
messageArrayBuffer->Neuter(); messageArrayBuffer->Neuter();
}; };
app->template ws<PerSocketData>(std::string(nativeString.getData(), nativeString.getLength()), std::move(behavior)); app->template ws<PerSocketData>(std::string(pattern.getString()), std::move(behavior));
/* Return this */ /* Return this */
args.GetReturnValue().Set(args.Holder()); args.GetReturnValue().Set(args.Holder());
@ -131,13 +131,16 @@ template <typename APP, typename F>
void uWS_App_get(F f, const FunctionCallbackInfo<Value> &args) { void uWS_App_get(F f, const FunctionCallbackInfo<Value> &args) {
APP *app = (APP *) args.Holder()->GetAlignedPointerFromInternalField(0); APP *app = (APP *) args.Holder()->GetAlignedPointerFromInternalField(0);
NativeString nativeString(args.GetIsolate(), args[0]); NativeString pattern(args.GetIsolate(), args[0]);
if (pattern.isInvalid(args)) {
return;
}
// todo: make it UniquePersistent /* todo: make it UniquePersistent */
Persistent<Function> *pf = new Persistent<Function>(); std::shared_ptr<Persistent<Function>> pf(new Persistent<Function>);
pf->Reset(args.GetIsolate(), Local<Function>::Cast(args[1])); pf->Reset(args.GetIsolate(), Local<Function>::Cast(args[1]));
(app->*f)(std::string(nativeString.getData(), nativeString.getLength()), [pf](auto *res, auto *req) { (app->*f)(std::string(pattern.getString()), [pf](auto *res, auto *req) {
HandleScope hs(isolate); HandleScope hs(isolate);
Local<Object> resObject = HttpResponseWrapper::getResInstance<APP>(); Local<Object> resObject = HttpResponseWrapper::getResInstance<APP>();
@ -148,6 +151,12 @@ void uWS_App_get(F f, const FunctionCallbackInfo<Value> &args) {
Local<Value> argv[] = {resObject, reqObject}; Local<Value> argv[] = {resObject, reqObject};
Local<Function>::New(isolate, *pf)->Call(isolate->GetCurrentContext()->Global(), 2, argv); Local<Function>::New(isolate, *pf)->Call(isolate->GetCurrentContext()->Global(), 2, argv);
/* Properly invalidate req */
reqObject->SetAlignedPointerInInternalField(0, nullptr);
/* µWS itself will terminate if not responded and not attached
* onAborted handler, so we can assume it's done */
}); });
args.GetReturnValue().Set(args.Holder()); args.GetReturnValue().Set(args.Holder());
@ -164,10 +173,16 @@ void uWS_App_listen(const FunctionCallbackInfo<Value> &args) {
Local<Function>::Cast(args[1])->Call(isolate->GetCurrentContext()->Global(), 1, argv); Local<Function>::Cast(args[1])->Call(isolate->GetCurrentContext()->Global(), 1, argv);
}); });
// Return this
args.GetReturnValue().Set(args.Holder()); args.GetReturnValue().Set(args.Holder());
} }
/* This is very risky */
template <typename APP>
void uWS_App_free(const FunctionCallbackInfo<Value> &args) {
APP *app = (APP *) args.Holder()->GetAlignedPointerFromInternalField(0);
delete app;
}
template <typename APP> template <typename APP>
void uWS_App(const FunctionCallbackInfo<Value> &args) { void uWS_App(const FunctionCallbackInfo<Value> &args) {
Local<FunctionTemplate> appTemplate = FunctionTemplate::New(isolate); Local<FunctionTemplate> appTemplate = FunctionTemplate::New(isolate);
@ -190,29 +205,41 @@ void uWS_App(const FunctionCallbackInfo<Value> &args) {
if (args.Length() == 1) { if (args.Length() == 1) {
/* Key file name */ /* Key file name */
NativeString keyFileNameValue(isolate, Local<Object>::Cast(args[0])->Get(String::NewFromUtf8(isolate, "key_file_name"))); NativeString keyFileNameValue(isolate, Local<Object>::Cast(args[0])->Get(String::NewFromUtf8(isolate, "key_file_name")));
if (keyFileNameValue.getLength()) { if (keyFileNameValue.isInvalid(args)) {
keyFileName.append(keyFileNameValue.getData(), keyFileNameValue.getLength()); return;
}
if (keyFileNameValue.getString().length()) {
keyFileName.append(keyFileNameValue.getString());
ssl_options.key_file_name = keyFileName.c_str(); ssl_options.key_file_name = keyFileName.c_str();
} }
/* Cert file name */ /* Cert file name */
NativeString certFileNameValue(isolate, Local<Object>::Cast(args[0])->Get(String::NewFromUtf8(isolate, "cert_file_name"))); NativeString certFileNameValue(isolate, Local<Object>::Cast(args[0])->Get(String::NewFromUtf8(isolate, "cert_file_name")));
if (certFileNameValue.getLength()) { if (certFileNameValue.isInvalid(args)) {
certFileName.append(certFileNameValue.getData(), certFileNameValue.getLength()); return;
}
if (certFileNameValue.getString().length()) {
certFileName.append(certFileNameValue.getString());
ssl_options.cert_file_name = certFileName.c_str(); ssl_options.cert_file_name = certFileName.c_str();
} }
/* Passphrase */ /* Passphrase */
NativeString passphraseValue(isolate, Local<Object>::Cast(args[0])->Get(String::NewFromUtf8(isolate, "passphrase"))); NativeString passphraseValue(isolate, Local<Object>::Cast(args[0])->Get(String::NewFromUtf8(isolate, "passphrase")));
if (passphraseValue.getLength()) { if (passphraseValue.isInvalid(args)) {
passphrase.append(passphraseValue.getData(), passphraseValue.getLength()); return;
}
if (passphraseValue.getString().length()) {
passphrase.append(passphraseValue.getString());
ssl_options.passphrase = passphrase.c_str(); ssl_options.passphrase = passphrase.c_str();
} }
/* DH params file name */ /* DH params file name */
NativeString dhParamsFileNameValue(isolate, Local<Object>::Cast(args[0])->Get(String::NewFromUtf8(isolate, "dh_params_file_name"))); NativeString dhParamsFileNameValue(isolate, Local<Object>::Cast(args[0])->Get(String::NewFromUtf8(isolate, "dh_params_file_name")));
if (dhParamsFileNameValue.getLength()) { if (dhParamsFileNameValue.isInvalid(args)) {
dhParamsFileName.append(dhParamsFileNameValue.getData(), dhParamsFileNameValue.getLength()); return;
}
if (dhParamsFileNameValue.getString().length()) {
dhParamsFileName.append(dhParamsFileNameValue.getString());
ssl_options.dh_params_file_name = dhParamsFileName.c_str(); ssl_options.dh_params_file_name = dhParamsFileName.c_str();
} }
} }

View File

@ -3,58 +3,73 @@
#include "Utilities.h" #include "Utilities.h"
using namespace v8; using namespace v8;
// du behver inte klona dessa
// det finns bara en enda giltig request vid någon tid, och det är alltid
// inom en callback
// håll en färdig request och tillåt functioner endast när du är inom callbacken
/* This one is the same for SSL and non-SSL */ /* This one is the same for SSL and non-SSL */
struct HttpRequestWrapper { struct HttpRequestWrapper {
static Persistent<Object> reqTemplate; static Persistent<Object> reqTemplate;
// todo: refuse all function calls if we are not inside correct callback
static inline uWS::HttpRequest *getHttpRequest(const FunctionCallbackInfo<Value> &args) { static inline uWS::HttpRequest *getHttpRequest(const FunctionCallbackInfo<Value> &args) {
return ((uWS::HttpRequest *) args.Holder()->GetAlignedPointerFromInternalField(0)); /* Thow on deleted request */
auto *req = (uWS::HttpRequest *) args.Holder()->GetAlignedPointerFromInternalField(0);
if (!req) {
args.GetReturnValue().Set(isolate->ThrowException(String::NewFromUtf8(isolate, "Using uWS.HttpRequest past its request handler return is forbidden (it is stack allocated).")));
}
return req;
} }
/* Takes int, returns string (must be in bounds) */ /* Takes int, returns string (must be in bounds) */
static void req_getParameter(const FunctionCallbackInfo<Value> &args) { static void req_getParameter(const FunctionCallbackInfo<Value> &args) {
auto *req = getHttpRequest(args);
if (req) {
int index = args[0]->Uint32Value(); int index = args[0]->Uint32Value();
std::string_view parameter = getHttpRequest(args)->getParameter(index); std::string_view parameter = req->getParameter(index);
args.GetReturnValue().Set(String::NewFromUtf8(isolate, parameter.data(), v8::String::kNormalString, parameter.length())); args.GetReturnValue().Set(String::NewFromUtf8(isolate, parameter.data(), v8::String::kNormalString, parameter.length()));
} }
}
/* Takes nothing, returns string */ /* Takes nothing, returns string */
static void req_getUrl(const FunctionCallbackInfo<Value> &args) { static void req_getUrl(const FunctionCallbackInfo<Value> &args) {
std::string_view url = getHttpRequest(args)->getUrl(); auto *req = getHttpRequest(args);
if (req) {
std::string_view url = req->getUrl();
args.GetReturnValue().Set(String::NewFromUtf8(isolate, url.data(), v8::String::kNormalString, url.length())); args.GetReturnValue().Set(String::NewFromUtf8(isolate, url.data(), v8::String::kNormalString, url.length()));
} }
}
/* Takes String, returns String */ /* Takes String, returns String */
static void req_getHeader(const FunctionCallbackInfo<Value> &args) { static void req_getHeader(const FunctionCallbackInfo<Value> &args) {
auto *req = getHttpRequest(args);
if (req) {
NativeString data(args.GetIsolate(), args[0]); NativeString data(args.GetIsolate(), args[0]);
char *buf = data.getData(); int length = data.getLength(); if (data.isInvalid(args)) {
return;
}
std::string_view header = getHttpRequest(args)->getHeader(std::string_view(buf, length)); std::string_view header = req->getHeader(data.getString());
args.GetReturnValue().Set(String::NewFromUtf8(isolate, header.data(), v8::String::kNormalString, header.length())); args.GetReturnValue().Set(String::NewFromUtf8(isolate, header.data(), v8::String::kNormalString, header.length()));
} }
}
/* Takes nothing, returns string */
static void req_getMethod(const FunctionCallbackInfo<Value> &args) { static void req_getMethod(const FunctionCallbackInfo<Value> &args) {
std::string_view method = getHttpRequest(args)->getMethod(); auto *req = getHttpRequest(args);
if (req) {
std::string_view method = req->getMethod();
args.GetReturnValue().Set(String::NewFromUtf8(isolate, method.data(), v8::String::kNormalString, method.length())); args.GetReturnValue().Set(String::NewFromUtf8(isolate, method.data(), v8::String::kNormalString, method.length()));
} }
}
static void req_getQuery(const FunctionCallbackInfo<Value> &args) { static void req_getQuery(const FunctionCallbackInfo<Value> &args) {
std::string_view query = getHttpRequest(args)->getQuery(); auto *req = getHttpRequest(args);
if (req) {
std::string_view query = req->getQuery();
args.GetReturnValue().Set(String::NewFromUtf8(isolate, query.data(), v8::String::kNormalString, query.length())); args.GetReturnValue().Set(String::NewFromUtf8(isolate, query.data(), v8::String::kNormalString, query.length()));
} }
}
static void initReqTemplate() { static void initReqTemplate() {
/* We do clone every request object, we could share them, they are illegal to use outside the function anyways */ /* We do clone every request object, we could share them, they are illegal to use outside the function anyways */
@ -62,7 +77,6 @@ struct HttpRequestWrapper {
reqTemplateLocal->SetClassName(String::NewFromUtf8(isolate, "uWS.HttpRequest")); reqTemplateLocal->SetClassName(String::NewFromUtf8(isolate, "uWS.HttpRequest"));
reqTemplateLocal->InstanceTemplate()->SetInternalFieldCount(1); reqTemplateLocal->InstanceTemplate()->SetInternalFieldCount(1);
/* Register our functions */ /* Register our functions */
reqTemplateLocal->PrototypeTemplate()->Set(String::NewFromUtf8(isolate, "getHeader"), FunctionTemplate::New(isolate, req_getHeader)); reqTemplateLocal->PrototypeTemplate()->Set(String::NewFromUtf8(isolate, "getHeader"), FunctionTemplate::New(isolate, req_getHeader));
reqTemplateLocal->PrototypeTemplate()->Set(String::NewFromUtf8(isolate, "getParameter"), FunctionTemplate::New(isolate, req_getParameter)); reqTemplateLocal->PrototypeTemplate()->Set(String::NewFromUtf8(isolate, "getParameter"), FunctionTemplate::New(isolate, req_getParameter));
@ -76,8 +90,6 @@ struct HttpRequestWrapper {
} }
static Local<Object> getReqInstance() { static Local<Object> getReqInstance() {
// if we attach a number that counts up to this req we can check if the number is still valid when calling functions?
return Local<Object>::New(isolate, reqTemplate)->Clone(); return Local<Object>::New(isolate, reqTemplate)->Clone();
} }
}; };

View File

@ -8,24 +8,38 @@ struct HttpResponseWrapper {
template <bool SSL> template <bool SSL>
static inline uWS::HttpResponse<SSL> *getHttpResponse(const FunctionCallbackInfo<Value> &args) { static inline uWS::HttpResponse<SSL> *getHttpResponse(const FunctionCallbackInfo<Value> &args) {
return (uWS::HttpResponse<SSL> *) args.Holder()->GetAlignedPointerFromInternalField(0); auto *res = (uWS::HttpResponse<SSL> *) args.Holder()->GetAlignedPointerFromInternalField(0);
if (!res) {
args.GetReturnValue().Set(isolate->ThrowException(String::NewFromUtf8(isolate, "Invalid access of discarded (invalid, deleted) uWS.HttpResponse/SSLHttpResponse.")));
}
return res;
}
/* Marks this JS object invalid */
static inline void invalidateResObject(const FunctionCallbackInfo<Value> &args) {
args.Holder()->SetAlignedPointerInInternalField(0, nullptr);
} }
/* Takes nothing, kills the connection */ /* Takes nothing, kills the connection */
template <bool SSL> template <bool SSL>
static void res_close(const FunctionCallbackInfo<Value> &args) { static void res_close(const FunctionCallbackInfo<Value> &args) {
getHttpResponse<SSL>(args)->close(); auto *res = getHttpResponse<SSL>(args);
if (res) {
invalidateResObject(args);
res->close();
args.GetReturnValue().Set(args.Holder()); args.GetReturnValue().Set(args.Holder());
} }
}
/* Takes function of data and isLast. Expects nothing from callback, returns this */ /* Takes function of data and isLast. Expects nothing from callback, returns this */
template <bool SSL> template <bool SSL>
static void res_onData(const FunctionCallbackInfo<Value> &args) { static void res_onData(const FunctionCallbackInfo<Value> &args) {
auto *res = getHttpResponse<SSL>(args);
if (res) {
/* This thing perfectly fits in with unique_function, and will Reset on destructor */ /* This thing perfectly fits in with unique_function, and will Reset on destructor */
UniquePersistent<Function> p(isolate, Local<Function>::Cast(args[0])); UniquePersistent<Function> p(isolate, Local<Function>::Cast(args[0]));
getHttpResponse<SSL>(args)->onData([p = std::move(p)](std::string_view data, bool last) { res->onData([p = std::move(p)](std::string_view data, bool last) {
HandleScope hs(isolate); HandleScope hs(isolate);
Local<ArrayBuffer> dataArrayBuffer = ArrayBuffer::New(isolate, (void *) data.data(), data.length()); Local<ArrayBuffer> dataArrayBuffer = ArrayBuffer::New(isolate, (void *) data.data(), data.length());
@ -38,96 +52,151 @@ struct HttpResponseWrapper {
args.GetReturnValue().Set(args.Holder()); args.GetReturnValue().Set(args.Holder());
} }
}
/* Takes nothing, returns nothing. Cb wants nothing returned. */ /* Takes nothing, returns nothing. Cb wants nothing returned. */
template <bool SSL> template <bool SSL>
static void res_onAborted(const FunctionCallbackInfo<Value> &args) { static void res_onAborted(const FunctionCallbackInfo<Value> &args) {
auto *res = getHttpResponse<SSL>(args);
if (res) {
/* This thing perfectly fits in with unique_function, and will Reset on destructor */ /* This thing perfectly fits in with unique_function, and will Reset on destructor */
UniquePersistent<Function> p(isolate, Local<Function>::Cast(args[0])); UniquePersistent<Function> p(isolate, Local<Function>::Cast(args[0]));
getHttpResponse<SSL>(args)->onAborted([p = std::move(p)]() { /* This is how we capture res (C++ this in invocation of this function) */
UniquePersistent<Object> resObject(isolate, args.Holder());
res->onAborted([p = std::move(p), resObject = std::move(resObject)]() {
HandleScope hs(isolate); HandleScope hs(isolate);
/* Mark this resObject invalid */
Local<Object>::New(isolate, resObject)->SetAlignedPointerInInternalField(0, nullptr);
Local<Function>::New(isolate, p)->Call(isolate->GetCurrentContext()->Global(), 0, nullptr); Local<Function>::New(isolate, p)->Call(isolate->GetCurrentContext()->Global(), 0, nullptr);
}); });
args.GetReturnValue().Set(args.Holder()); args.GetReturnValue().Set(args.Holder());
} }
}
/* Returns the current write offset */ /* Returns the current write offset */
template <bool SSL> template <bool SSL>
static void res_getWriteOffset(const FunctionCallbackInfo<Value> &args) { static void res_getWriteOffset(const FunctionCallbackInfo<Value> &args) {
auto *res = getHttpResponse<SSL>(args);
if (res) {
args.GetReturnValue().Set(Integer::New(isolate, getHttpResponse<SSL>(args)->getWriteOffset())); args.GetReturnValue().Set(Integer::New(isolate, getHttpResponse<SSL>(args)->getWriteOffset()));
} }
}
/* Takes function of bool(int), returns this */ /* Takes function of bool(int), returns this */
template <bool SSL> template <bool SSL>
static void res_onWritable(const FunctionCallbackInfo<Value> &args) { static void res_onWritable(const FunctionCallbackInfo<Value> &args) {
auto *res = getHttpResponse<SSL>(args);
if (res) {
/* This thing perfectly fits in with unique_function, and will Reset on destructor */ /* This thing perfectly fits in with unique_function, and will Reset on destructor */
UniquePersistent<Function> p(isolate, Local<Function>::Cast(args[0])); UniquePersistent<Function> p(isolate, Local<Function>::Cast(args[0]));
getHttpResponse<SSL>(args)->onWritable([p = std::move(p)](int offset) { res->onWritable([p = std::move(p)](int offset) {
HandleScope hs(isolate); HandleScope hs(isolate);
Local<Value> argv[] = {Integer::NewFromUnsigned(isolate, offset)}; Local<Value> argv[] = {Integer::NewFromUnsigned(isolate, offset)};
return Local<Function>::New(isolate, p)->Call(isolate->GetCurrentContext()->Global(), 1, argv)->BooleanValue(); return Local<Function>::New(isolate, p)->Call(isolate->GetCurrentContext()->Global(), 1, argv)->BooleanValue();
/* How important is this return? */
}); });
args.GetReturnValue().Set(args.Holder()); args.GetReturnValue().Set(args.Holder());
} }
}
/* Takes string or arraybuffer, returns this */ /* Takes string or arraybuffer, returns this */
template <bool SSL> template <bool SSL>
static void res_writeStatus(const FunctionCallbackInfo<Value> &args) { static void res_writeStatus(const FunctionCallbackInfo<Value> &args) {
auto *res = getHttpResponse<SSL>(args);
if (res) {
NativeString data(args.GetIsolate(), args[0]); NativeString data(args.GetIsolate(), args[0]);
getHttpResponse<SSL>(args)->writeStatus(std::string_view(data.getData(), data.getLength())); if (data.isInvalid(args)) {
return;
}
res->writeStatus(data.getString());
args.GetReturnValue().Set(args.Holder()); args.GetReturnValue().Set(args.Holder());
} }
}
/* Takes string or arraybuffer, returns this */ /* Takes string or arraybuffer, returns this */
template <bool SSL> template <bool SSL>
static void res_end(const FunctionCallbackInfo<Value> &args) { static void res_end(const FunctionCallbackInfo<Value> &args) {
auto *res = getHttpResponse<SSL>(args);
if (res) {
NativeString data(args.GetIsolate(), args[0]); NativeString data(args.GetIsolate(), args[0]);
getHttpResponse<SSL>(args)->end(std::string_view(data.getData(), data.getLength())); if (data.isInvalid(args)) {
return;
}
invalidateResObject(args);
res->end(data.getString());
args.GetReturnValue().Set(args.Holder()); args.GetReturnValue().Set(args.Holder());
} }
}
/* Takes data and optionally totalLength, returns true for success, false for backpressure */ /* Takes data and optionally totalLength, returns true for success, false for backpressure */
template <bool SSL> template <bool SSL>
static void res_tryEnd(const FunctionCallbackInfo<Value> &args) { static void res_tryEnd(const FunctionCallbackInfo<Value> &args) {
auto *res = getHttpResponse<SSL>(args);
if (res) {
NativeString data(args.GetIsolate(), args[0]); NativeString data(args.GetIsolate(), args[0]);
if (data.isInvalid(args)) {
return;
}
int totalSize = 0; int totalSize = 0;
if (args.Length() > 1) { if (args.Length() > 1) {
totalSize = args[1]->Uint32Value(); totalSize = args[1]->Uint32Value();
} }
bool ok = getHttpResponse<SSL>(args)->tryEnd(std::string_view(data.getData(), data.getLength()), totalSize); bool ok = res->tryEnd(data.getString(), totalSize);
/* Invalidate this object if we responded completely */
if (res->hasResponded()) {
invalidateResObject(args);
}
args.GetReturnValue().Set(Boolean::New(isolate, ok)); args.GetReturnValue().Set(Boolean::New(isolate, ok));
} }
}
/* Takes data, returns true for success, false for backpressure */ /* Takes data, returns true for success, false for backpressure */
template <bool SSL> template <bool SSL>
static void res_write(const FunctionCallbackInfo<Value> &args) { static void res_write(const FunctionCallbackInfo<Value> &args) {
auto *res = getHttpResponse<SSL>(args);
if (res) {
NativeString data(args.GetIsolate(), args[0]); NativeString data(args.GetIsolate(), args[0]);
bool ok = getHttpResponse<SSL>(args)->write(std::string_view(data.getData(), data.getLength())); if (data.isInvalid(args)) {
return;
}
bool ok = res->write(data.getString());
args.GetReturnValue().Set(Boolean::New(isolate, ok)); args.GetReturnValue().Set(Boolean::New(isolate, ok));
} }
}
/* Takes key, value. Returns this */ /* Takes key, value. Returns this */
template <bool SSL> template <bool SSL>
static void res_writeHeader(const FunctionCallbackInfo<Value> &args) { static void res_writeHeader(const FunctionCallbackInfo<Value> &args) {
auto *res = getHttpResponse<SSL>(args);
if (res) {
NativeString header(args.GetIsolate(), args[0]); NativeString header(args.GetIsolate(), args[0]);
if (header.isInvalid(args)) {
return;
}
NativeString value(args.GetIsolate(), args[1]); NativeString value(args.GetIsolate(), args[1]);
getHttpResponse<SSL>(args)->writeHeader(std::string_view(header.getData(), header.getLength()), if (value.isInvalid(args)) {
std::string_view(value.getData(), value.getLength())); return;
}
res->writeHeader(header.getString(),value.getString());
args.GetReturnValue().Set(args.Holder()); args.GetReturnValue().Set(args.Holder());
} }
}
template <bool SSL> template <bool SSL>
static void initResTemplate() { static void initResTemplate() {

View File

@ -1,11 +1,15 @@
#ifndef ADDON_UTILITIES_H #ifndef ADDON_UTILITIES_H
#define ADDON_UTILITIES_H #define ADDON_UTILITIES_H
#include <v8.h>
using namespace v8;
class NativeString { class NativeString {
char *data; char *data;
size_t length; size_t length;
char utf8ValueMemory[sizeof(String::Utf8Value)]; char utf8ValueMemory[sizeof(String::Utf8Value)];
String::Utf8Value *utf8Value = nullptr; String::Utf8Value *utf8Value = nullptr;
bool invalid = false;
public: public:
NativeString(Isolate *isolate, const Local<Value> &value) { NativeString(Isolate *isolate, const Local<Value> &value) {
if (value->IsUndefined()) { if (value->IsUndefined()) {
@ -26,14 +30,21 @@ public:
length = contents.ByteLength(); length = contents.ByteLength();
data = (char *) contents.Data(); data = (char *) contents.Data();
} else { } else {
static char empty[] = ""; invalid = true;
data = empty;
length = 0;
} }
} }
char *getData() {return data;} bool isInvalid(const FunctionCallbackInfo<Value> &args) {
size_t getLength() {return length;} if (invalid) {
args.GetReturnValue().Set(isolate->ThrowException(String::NewFromUtf8(isolate, "Text and data can only be passed by String, ArrayBuffer or TypedArray.")));
}
return invalid;
}
std::string_view getString() {
return {data, length};
}
~NativeString() { ~NativeString() {
if (utf8Value) { if (utf8Value) {
utf8Value->~Utf8Value(); utf8Value->~Utf8Value();

View File

@ -3,6 +3,7 @@
#include "Utilities.h" #include "Utilities.h"
using namespace v8; using namespace v8;
// todo: also check for use after free here!
// todo: probably isCorked, cork should be exposed? // todo: probably isCorked, cork should be exposed?
struct WebSocketWrapper { struct WebSocketWrapper {
@ -17,18 +18,16 @@ struct WebSocketWrapper {
template <bool SSL> template <bool SSL>
static void uWS_WebSocket_close(const FunctionCallbackInfo<Value> &args) { static void uWS_WebSocket_close(const FunctionCallbackInfo<Value> &args) {
int code = 0; int code = 0;
std::string_view message;
if (args.Length() >= 1) { if (args.Length() >= 1) {
code = args[0]->Uint32Value(); code = args[0]->Uint32Value();
} }
if (args.Length() >= 2) { NativeString message(args.GetIsolate(), args[1]);
NativeString nativeString(args.GetIsolate(), args[1]); if (message.isInvalid(args)) {
message = {nativeString.getData(), nativeString.getLength()}; return;
} }
getWebSocket<SSL>(args)->close(code, message); getWebSocket<SSL>(args)->close(code, message.getString());
} }
/* Takes nothing, returns integer */ /* Takes nothing, returns integer */
@ -41,13 +40,14 @@ struct WebSocketWrapper {
/* Takes message, isBinary. Returns true on success, false otherwise */ /* Takes message, isBinary. Returns true on success, false otherwise */
template <bool SSL> template <bool SSL>
static void uWS_WebSocket_send(const FunctionCallbackInfo<Value> &args) { static void uWS_WebSocket_send(const FunctionCallbackInfo<Value> &args) {
NativeString nativeString(args.GetIsolate(), args[0]); NativeString message(args.GetIsolate(), args[0]);
if (message.isInvalid(args)) {
return;
}
bool isBinary = args[1]->BooleanValue(); bool isBinary = args[1]->BooleanValue();
bool ok = getWebSocket<SSL>(args)->send( bool ok = getWebSocket<SSL>(args)->send(message.getString(), isBinary ? uWS::OpCode::BINARY : uWS::OpCode::TEXT);
std::string_view(nativeString.getData(), nativeString.getLength()), isBinary ? uWS::OpCode::BINARY : uWS::OpCode::TEXT
);
args.GetReturnValue().Set(Boolean::New(isolate, ok)); args.GetReturnValue().Set(Boolean::New(isolate, ok));
} }

View File

@ -23,8 +23,6 @@
#include <type_traits> #include <type_traits>
using namespace v8; using namespace v8;
void emptyNextTickQueue(Isolate *isolate);
/* These two are definitely static */ /* These two are definitely static */
std::vector<UniquePersistent<Function>> nextTickQueue; std::vector<UniquePersistent<Function>> nextTickQueue;
Isolate *isolate; Isolate *isolate;