Throw exceptions on ill-use instead of segfaulting

This commit is contained in:
Alex Hultman 2019-01-18 16:33:56 +01:00
parent be8aaeeba7
commit d067388df4
6 changed files with 229 additions and 112 deletions

View File

@ -1,5 +1,3 @@
// test so that we pass Autobahn with compression/without compression with SSL/without SSL
#include "App.h"
#include <v8.h>
#include "Utilities.h"
@ -11,8 +9,10 @@ void uWS_App_ws(const FunctionCallbackInfo<Value> &args) {
APP *app = (APP *) args.Holder()->GetAlignedPointerFromInternalField(0);
typename APP::WebSocketBehavior behavior = {};
// pattern
NativeString nativeString(args.GetIsolate(), args[0]);
NativeString pattern(args.GetIsolate(), args[0]);
if (pattern.isInvalid(args)) {
return;
}
// todo: small leak here, should be unique_ptrs moved in
Persistent<Function> *openPf = new Persistent<Function>();
@ -120,7 +120,7 @@ void uWS_App_ws(const FunctionCallbackInfo<Value> &args) {
messageArrayBuffer->Neuter();
};
app->template ws<PerSocketData>(std::string(nativeString.getData(), nativeString.getLength()), std::move(behavior));
app->template ws<PerSocketData>(std::string(pattern.getString()), std::move(behavior));
/* Return this */
args.GetReturnValue().Set(args.Holder());
@ -131,13 +131,16 @@ template <typename APP, typename F>
void uWS_App_get(F f, const FunctionCallbackInfo<Value> &args) {
APP *app = (APP *) args.Holder()->GetAlignedPointerFromInternalField(0);
NativeString nativeString(args.GetIsolate(), args[0]);
NativeString pattern(args.GetIsolate(), args[0]);
if (pattern.isInvalid(args)) {
return;
}
// todo: make it UniquePersistent
Persistent<Function> *pf = new Persistent<Function>();
/* todo: make it UniquePersistent */
std::shared_ptr<Persistent<Function>> pf(new Persistent<Function>);
pf->Reset(args.GetIsolate(), Local<Function>::Cast(args[1]));
(app->*f)(std::string(nativeString.getData(), nativeString.getLength()), [pf](auto *res, auto *req) {
(app->*f)(std::string(pattern.getString()), [pf](auto *res, auto *req) {
HandleScope hs(isolate);
Local<Object> resObject = HttpResponseWrapper::getResInstance<APP>();
@ -148,6 +151,12 @@ void uWS_App_get(F f, const FunctionCallbackInfo<Value> &args) {
Local<Value> argv[] = {resObject, reqObject};
Local<Function>::New(isolate, *pf)->Call(isolate->GetCurrentContext()->Global(), 2, argv);
/* Properly invalidate req */
reqObject->SetAlignedPointerInInternalField(0, nullptr);
/* µWS itself will terminate if not responded and not attached
* onAborted handler, so we can assume it's done */
});
args.GetReturnValue().Set(args.Holder());
@ -164,10 +173,16 @@ void uWS_App_listen(const FunctionCallbackInfo<Value> &args) {
Local<Function>::Cast(args[1])->Call(isolate->GetCurrentContext()->Global(), 1, argv);
});
// Return this
args.GetReturnValue().Set(args.Holder());
}
/* This is very risky */
template <typename APP>
void uWS_App_free(const FunctionCallbackInfo<Value> &args) {
APP *app = (APP *) args.Holder()->GetAlignedPointerFromInternalField(0);
delete app;
}
template <typename APP>
void uWS_App(const FunctionCallbackInfo<Value> &args) {
Local<FunctionTemplate> appTemplate = FunctionTemplate::New(isolate);
@ -190,29 +205,41 @@ void uWS_App(const FunctionCallbackInfo<Value> &args) {
if (args.Length() == 1) {
/* Key file name */
NativeString keyFileNameValue(isolate, Local<Object>::Cast(args[0])->Get(String::NewFromUtf8(isolate, "key_file_name")));
if (keyFileNameValue.getLength()) {
keyFileName.append(keyFileNameValue.getData(), keyFileNameValue.getLength());
if (keyFileNameValue.isInvalid(args)) {
return;
}
if (keyFileNameValue.getString().length()) {
keyFileName.append(keyFileNameValue.getString());
ssl_options.key_file_name = keyFileName.c_str();
}
/* Cert file name */
NativeString certFileNameValue(isolate, Local<Object>::Cast(args[0])->Get(String::NewFromUtf8(isolate, "cert_file_name")));
if (certFileNameValue.getLength()) {
certFileName.append(certFileNameValue.getData(), certFileNameValue.getLength());
if (certFileNameValue.isInvalid(args)) {
return;
}
if (certFileNameValue.getString().length()) {
certFileName.append(certFileNameValue.getString());
ssl_options.cert_file_name = certFileName.c_str();
}
/* Passphrase */
NativeString passphraseValue(isolate, Local<Object>::Cast(args[0])->Get(String::NewFromUtf8(isolate, "passphrase")));
if (passphraseValue.getLength()) {
passphrase.append(passphraseValue.getData(), passphraseValue.getLength());
if (passphraseValue.isInvalid(args)) {
return;
}
if (passphraseValue.getString().length()) {
passphrase.append(passphraseValue.getString());
ssl_options.passphrase = passphrase.c_str();
}
/* DH params file name */
NativeString dhParamsFileNameValue(isolate, Local<Object>::Cast(args[0])->Get(String::NewFromUtf8(isolate, "dh_params_file_name")));
if (dhParamsFileNameValue.getLength()) {
dhParamsFileName.append(dhParamsFileNameValue.getData(), dhParamsFileNameValue.getLength());
if (dhParamsFileNameValue.isInvalid(args)) {
return;
}
if (dhParamsFileNameValue.getString().length()) {
dhParamsFileName.append(dhParamsFileNameValue.getString());
ssl_options.dh_params_file_name = dhParamsFileName.c_str();
}
}

View File

@ -3,57 +3,72 @@
#include "Utilities.h"
using namespace v8;
// du behver inte klona dessa
// det finns bara en enda giltig request vid någon tid, och det är alltid
// inom en callback
// håll en färdig request och tillåt functioner endast när du är inom callbacken
/* This one is the same for SSL and non-SSL */
struct HttpRequestWrapper {
static Persistent<Object> reqTemplate;
// todo: refuse all function calls if we are not inside correct callback
static inline uWS::HttpRequest *getHttpRequest(const FunctionCallbackInfo<Value> &args) {
return ((uWS::HttpRequest *) args.Holder()->GetAlignedPointerFromInternalField(0));
/* Thow on deleted request */
auto *req = (uWS::HttpRequest *) args.Holder()->GetAlignedPointerFromInternalField(0);
if (!req) {
args.GetReturnValue().Set(isolate->ThrowException(String::NewFromUtf8(isolate, "Using uWS.HttpRequest past its request handler return is forbidden (it is stack allocated).")));
}
return req;
}
/* Takes int, returns string (must be in bounds) */
static void req_getParameter(const FunctionCallbackInfo<Value> &args) {
int index = args[0]->Uint32Value();
std::string_view parameter = getHttpRequest(args)->getParameter(index);
auto *req = getHttpRequest(args);
if (req) {
int index = args[0]->Uint32Value();
std::string_view parameter = req->getParameter(index);
args.GetReturnValue().Set(String::NewFromUtf8(isolate, parameter.data(), v8::String::kNormalString, parameter.length()));
args.GetReturnValue().Set(String::NewFromUtf8(isolate, parameter.data(), v8::String::kNormalString, parameter.length()));
}
}
/* Takes nothing, returns string */
static void req_getUrl(const FunctionCallbackInfo<Value> &args) {
std::string_view url = getHttpRequest(args)->getUrl();
auto *req = getHttpRequest(args);
if (req) {
std::string_view url = req->getUrl();
args.GetReturnValue().Set(String::NewFromUtf8(isolate, url.data(), v8::String::kNormalString, url.length()));
args.GetReturnValue().Set(String::NewFromUtf8(isolate, url.data(), v8::String::kNormalString, url.length()));
}
}
/* Takes String, returns String */
static void req_getHeader(const FunctionCallbackInfo<Value> &args) {
NativeString data(args.GetIsolate(), args[0]);
char *buf = data.getData(); int length = data.getLength();
auto *req = getHttpRequest(args);
if (req) {
NativeString data(args.GetIsolate(), args[0]);
if (data.isInvalid(args)) {
return;
}
std::string_view header = getHttpRequest(args)->getHeader(std::string_view(buf, length));
std::string_view header = req->getHeader(data.getString());
args.GetReturnValue().Set(String::NewFromUtf8(isolate, header.data(), v8::String::kNormalString, header.length()));
args.GetReturnValue().Set(String::NewFromUtf8(isolate, header.data(), v8::String::kNormalString, header.length()));
}
}
/* Takes nothing, returns string */
static void req_getMethod(const FunctionCallbackInfo<Value> &args) {
std::string_view method = getHttpRequest(args)->getMethod();
auto *req = getHttpRequest(args);
if (req) {
std::string_view method = req->getMethod();
args.GetReturnValue().Set(String::NewFromUtf8(isolate, method.data(), v8::String::kNormalString, method.length()));
args.GetReturnValue().Set(String::NewFromUtf8(isolate, method.data(), v8::String::kNormalString, method.length()));
}
}
static void req_getQuery(const FunctionCallbackInfo<Value> &args) {
std::string_view query = getHttpRequest(args)->getQuery();
auto *req = getHttpRequest(args);
if (req) {
std::string_view query = req->getQuery();
args.GetReturnValue().Set(String::NewFromUtf8(isolate, query.data(), v8::String::kNormalString, query.length()));
args.GetReturnValue().Set(String::NewFromUtf8(isolate, query.data(), v8::String::kNormalString, query.length()));
}
}
static void initReqTemplate() {
@ -62,7 +77,6 @@ struct HttpRequestWrapper {
reqTemplateLocal->SetClassName(String::NewFromUtf8(isolate, "uWS.HttpRequest"));
reqTemplateLocal->InstanceTemplate()->SetInternalFieldCount(1);
/* Register our functions */
reqTemplateLocal->PrototypeTemplate()->Set(String::NewFromUtf8(isolate, "getHeader"), FunctionTemplate::New(isolate, req_getHeader));
reqTemplateLocal->PrototypeTemplate()->Set(String::NewFromUtf8(isolate, "getParameter"), FunctionTemplate::New(isolate, req_getParameter));
@ -76,8 +90,6 @@ struct HttpRequestWrapper {
}
static Local<Object> getReqInstance() {
// if we attach a number that counts up to this req we can check if the number is still valid when calling functions?
return Local<Object>::New(isolate, reqTemplate)->Clone();
}
};

View File

@ -8,125 +8,194 @@ struct HttpResponseWrapper {
template <bool SSL>
static inline uWS::HttpResponse<SSL> *getHttpResponse(const FunctionCallbackInfo<Value> &args) {
return (uWS::HttpResponse<SSL> *) args.Holder()->GetAlignedPointerFromInternalField(0);
auto *res = (uWS::HttpResponse<SSL> *) args.Holder()->GetAlignedPointerFromInternalField(0);
if (!res) {
args.GetReturnValue().Set(isolate->ThrowException(String::NewFromUtf8(isolate, "Invalid access of discarded (invalid, deleted) uWS.HttpResponse/SSLHttpResponse.")));
}
return res;
}
/* Marks this JS object invalid */
static inline void invalidateResObject(const FunctionCallbackInfo<Value> &args) {
args.Holder()->SetAlignedPointerInInternalField(0, nullptr);
}
/* Takes nothing, kills the connection */
template <bool SSL>
static void res_close(const FunctionCallbackInfo<Value> &args) {
getHttpResponse<SSL>(args)->close();
args.GetReturnValue().Set(args.Holder());
auto *res = getHttpResponse<SSL>(args);
if (res) {
invalidateResObject(args);
res->close();
args.GetReturnValue().Set(args.Holder());
}
}
/* Takes function of data and isLast. Expects nothing from callback, returns this */
template <bool SSL>
static void res_onData(const FunctionCallbackInfo<Value> &args) {
/* This thing perfectly fits in with unique_function, and will Reset on destructor */
UniquePersistent<Function> p(isolate, Local<Function>::Cast(args[0]));
auto *res = getHttpResponse<SSL>(args);
if (res) {
/* This thing perfectly fits in with unique_function, and will Reset on destructor */
UniquePersistent<Function> p(isolate, Local<Function>::Cast(args[0]));
getHttpResponse<SSL>(args)->onData([p = std::move(p)](std::string_view data, bool last) {
HandleScope hs(isolate);
res->onData([p = std::move(p)](std::string_view data, bool last) {
HandleScope hs(isolate);
Local<ArrayBuffer> dataArrayBuffer = ArrayBuffer::New(isolate, (void *) data.data(), data.length());
Local<ArrayBuffer> dataArrayBuffer = ArrayBuffer::New(isolate, (void *) data.data(), data.length());
Local<Value> argv[] = {dataArrayBuffer, Boolean::New(isolate, last)};
Local<Function>::New(isolate, p)->Call(isolate->GetCurrentContext()->Global(), 2, argv);
Local<Value> argv[] = {dataArrayBuffer, Boolean::New(isolate, last)};
Local<Function>::New(isolate, p)->Call(isolate->GetCurrentContext()->Global(), 2, argv);
dataArrayBuffer->Neuter();
});
dataArrayBuffer->Neuter();
});
args.GetReturnValue().Set(args.Holder());
args.GetReturnValue().Set(args.Holder());
}
}
/* Takes nothing, returns nothing. Cb wants nothing returned. */
template <bool SSL>
static void res_onAborted(const FunctionCallbackInfo<Value> &args) {
/* This thing perfectly fits in with unique_function, and will Reset on destructor */
UniquePersistent<Function> p(isolate, Local<Function>::Cast(args[0]));
auto *res = getHttpResponse<SSL>(args);
if (res) {
/* This thing perfectly fits in with unique_function, and will Reset on destructor */
UniquePersistent<Function> p(isolate, Local<Function>::Cast(args[0]));
getHttpResponse<SSL>(args)->onAborted([p = std::move(p)]() {
HandleScope hs(isolate);
/* This is how we capture res (C++ this in invocation of this function) */
UniquePersistent<Object> resObject(isolate, args.Holder());
Local<Function>::New(isolate, p)->Call(isolate->GetCurrentContext()->Global(), 0, nullptr);
});
res->onAborted([p = std::move(p), resObject = std::move(resObject)]() {
HandleScope hs(isolate);
args.GetReturnValue().Set(args.Holder());
/* Mark this resObject invalid */
Local<Object>::New(isolate, resObject)->SetAlignedPointerInInternalField(0, nullptr);
Local<Function>::New(isolate, p)->Call(isolate->GetCurrentContext()->Global(), 0, nullptr);
});
args.GetReturnValue().Set(args.Holder());
}
}
/* Returns the current write offset */
template <bool SSL>
static void res_getWriteOffset(const FunctionCallbackInfo<Value> &args) {
args.GetReturnValue().Set(Integer::New(isolate, getHttpResponse<SSL>(args)->getWriteOffset()));
auto *res = getHttpResponse<SSL>(args);
if (res) {
args.GetReturnValue().Set(Integer::New(isolate, getHttpResponse<SSL>(args)->getWriteOffset()));
}
}
/* Takes function of bool(int), returns this */
template <bool SSL>
static void res_onWritable(const FunctionCallbackInfo<Value> &args) {
/* This thing perfectly fits in with unique_function, and will Reset on destructor */
UniquePersistent<Function> p(isolate, Local<Function>::Cast(args[0]));
auto *res = getHttpResponse<SSL>(args);
if (res) {
/* This thing perfectly fits in with unique_function, and will Reset on destructor */
UniquePersistent<Function> p(isolate, Local<Function>::Cast(args[0]));
getHttpResponse<SSL>(args)->onWritable([p = std::move(p)](int offset) {
HandleScope hs(isolate);
res->onWritable([p = std::move(p)](int offset) {
HandleScope hs(isolate);
Local<Value> argv[] = {Integer::NewFromUnsigned(isolate, offset)};
return Local<Function>::New(isolate, p)->Call(isolate->GetCurrentContext()->Global(), 1, argv)->BooleanValue();
});
Local<Value> argv[] = {Integer::NewFromUnsigned(isolate, offset)};
return Local<Function>::New(isolate, p)->Call(isolate->GetCurrentContext()->Global(), 1, argv)->BooleanValue();
/* How important is this return? */
});
args.GetReturnValue().Set(args.Holder());
args.GetReturnValue().Set(args.Holder());
}
}
/* Takes string or arraybuffer, returns this */
template <bool SSL>
static void res_writeStatus(const FunctionCallbackInfo<Value> &args) {
NativeString data(args.GetIsolate(), args[0]);
getHttpResponse<SSL>(args)->writeStatus(std::string_view(data.getData(), data.getLength()));
auto *res = getHttpResponse<SSL>(args);
if (res) {
NativeString data(args.GetIsolate(), args[0]);
if (data.isInvalid(args)) {
return;
}
res->writeStatus(data.getString());
args.GetReturnValue().Set(args.Holder());
args.GetReturnValue().Set(args.Holder());
}
}
/* Takes string or arraybuffer, returns this */
template <bool SSL>
static void res_end(const FunctionCallbackInfo<Value> &args) {
NativeString data(args.GetIsolate(), args[0]);
getHttpResponse<SSL>(args)->end(std::string_view(data.getData(), data.getLength()));
auto *res = getHttpResponse<SSL>(args);
if (res) {
NativeString data(args.GetIsolate(), args[0]);
if (data.isInvalid(args)) {
return;
}
invalidateResObject(args);
res->end(data.getString());
args.GetReturnValue().Set(args.Holder());
args.GetReturnValue().Set(args.Holder());
}
}
/* Takes data and optionally totalLength, returns true for success, false for backpressure */
template <bool SSL>
static void res_tryEnd(const FunctionCallbackInfo<Value> &args) {
NativeString data(args.GetIsolate(), args[0]);
auto *res = getHttpResponse<SSL>(args);
if (res) {
NativeString data(args.GetIsolate(), args[0]);
if (data.isInvalid(args)) {
return;
}
int totalSize = 0;
if (args.Length() > 1) {
totalSize = args[1]->Uint32Value();
int totalSize = 0;
if (args.Length() > 1) {
totalSize = args[1]->Uint32Value();
}
bool ok = res->tryEnd(data.getString(), totalSize);
/* Invalidate this object if we responded completely */
if (res->hasResponded()) {
invalidateResObject(args);
}
args.GetReturnValue().Set(Boolean::New(isolate, ok));
}
bool ok = getHttpResponse<SSL>(args)->tryEnd(std::string_view(data.getData(), data.getLength()), totalSize);
args.GetReturnValue().Set(Boolean::New(isolate, ok));
}
/* Takes data, returns true for success, false for backpressure */
template <bool SSL>
static void res_write(const FunctionCallbackInfo<Value> &args) {
NativeString data(args.GetIsolate(), args[0]);
bool ok = getHttpResponse<SSL>(args)->write(std::string_view(data.getData(), data.getLength()));
auto *res = getHttpResponse<SSL>(args);
if (res) {
NativeString data(args.GetIsolate(), args[0]);
if (data.isInvalid(args)) {
return;
}
bool ok = res->write(data.getString());
args.GetReturnValue().Set(Boolean::New(isolate, ok));
args.GetReturnValue().Set(Boolean::New(isolate, ok));
}
}
/* Takes key, value. Returns this */
template <bool SSL>
static void res_writeHeader(const FunctionCallbackInfo<Value> &args) {
NativeString header(args.GetIsolate(), args[0]);
NativeString value(args.GetIsolate(), args[1]);
getHttpResponse<SSL>(args)->writeHeader(std::string_view(header.getData(), header.getLength()),
std::string_view(value.getData(), value.getLength()));
auto *res = getHttpResponse<SSL>(args);
if (res) {
NativeString header(args.GetIsolate(), args[0]);
if (header.isInvalid(args)) {
return;
}
NativeString value(args.GetIsolate(), args[1]);
if (value.isInvalid(args)) {
return;
}
res->writeHeader(header.getString(),value.getString());
args.GetReturnValue().Set(args.Holder());
args.GetReturnValue().Set(args.Holder());
}
}
template <bool SSL>

View File

@ -1,11 +1,15 @@
#ifndef ADDON_UTILITIES_H
#define ADDON_UTILITIES_H
#include <v8.h>
using namespace v8;
class NativeString {
char *data;
size_t length;
char utf8ValueMemory[sizeof(String::Utf8Value)];
String::Utf8Value *utf8Value = nullptr;
bool invalid = false;
public:
NativeString(Isolate *isolate, const Local<Value> &value) {
if (value->IsUndefined()) {
@ -26,14 +30,21 @@ public:
length = contents.ByteLength();
data = (char *) contents.Data();
} else {
static char empty[] = "";
data = empty;
length = 0;
invalid = true;
}
}
char *getData() {return data;}
size_t getLength() {return length;}
bool isInvalid(const FunctionCallbackInfo<Value> &args) {
if (invalid) {
args.GetReturnValue().Set(isolate->ThrowException(String::NewFromUtf8(isolate, "Text and data can only be passed by String, ArrayBuffer or TypedArray.")));
}
return invalid;
}
std::string_view getString() {
return {data, length};
}
~NativeString() {
if (utf8Value) {
utf8Value->~Utf8Value();

View File

@ -3,6 +3,7 @@
#include "Utilities.h"
using namespace v8;
// todo: also check for use after free here!
// todo: probably isCorked, cork should be exposed?
struct WebSocketWrapper {
@ -17,18 +18,16 @@ struct WebSocketWrapper {
template <bool SSL>
static void uWS_WebSocket_close(const FunctionCallbackInfo<Value> &args) {
int code = 0;
std::string_view message;
if (args.Length() >= 1) {
code = args[0]->Uint32Value();
}
if (args.Length() >= 2) {
NativeString nativeString(args.GetIsolate(), args[1]);
message = {nativeString.getData(), nativeString.getLength()};
NativeString message(args.GetIsolate(), args[1]);
if (message.isInvalid(args)) {
return;
}
getWebSocket<SSL>(args)->close(code, message);
getWebSocket<SSL>(args)->close(code, message.getString());
}
/* Takes nothing, returns integer */
@ -41,13 +40,14 @@ struct WebSocketWrapper {
/* Takes message, isBinary. Returns true on success, false otherwise */
template <bool SSL>
static void uWS_WebSocket_send(const FunctionCallbackInfo<Value> &args) {
NativeString nativeString(args.GetIsolate(), args[0]);
NativeString message(args.GetIsolate(), args[0]);
if (message.isInvalid(args)) {
return;
}
bool isBinary = args[1]->BooleanValue();
bool ok = getWebSocket<SSL>(args)->send(
std::string_view(nativeString.getData(), nativeString.getLength()), isBinary ? uWS::OpCode::BINARY : uWS::OpCode::TEXT
);
bool ok = getWebSocket<SSL>(args)->send(message.getString(), isBinary ? uWS::OpCode::BINARY : uWS::OpCode::TEXT);
args.GetReturnValue().Set(Boolean::New(isolate, ok));
}

View File

@ -23,8 +23,6 @@
#include <type_traits>
using namespace v8;
void emptyNextTickQueue(Isolate *isolate);
/* These two are definitely static */
std::vector<UniquePersistent<Function>> nextTickQueue;
Isolate *isolate;