module upromised.tls;
version (hasOpenssl):
import deimos.openssl.ssl;
import upromised.stream : Stream;
import upromised.promise : DelegatePromise, Promise, PromiseIterator;
import upromised : fatal;
import std.exception : enforce;
import std.format : format;

// One-time OpenSSL initialisation: registers the library and loads the
// human-readable error strings used by OpensslError.
shared static this() {
	import deimos.openssl.err : ERR_load_crypto_strings;

	SSL_library_init();
	ERR_load_crypto_strings();
	SSL_load_error_strings();
}

/// Typed wrapper around one SSL "ex data" slot.
///
/// Each distinct `name` instantiates its own `index`, allocated once at
/// program start through SSL_get_ex_new_index.
private struct ExData(alias name, T) {
	private shared static const(int) index;
	shared static this() {
		index = SSL_get_ex_new_index(0, null, null, null, null);
	}

	/// Stores `data` in this slot of `ctx`. The pointer is stored as-is;
	/// the caller keeps ownership of whatever it points to.
	static set(SSL* ctx, T* data) {
		int rc = SSL_set_ex_data(ctx, index, cast(void*)data);
		assert(rc == 1);
	}

	/// Returns the pointer previously stored in this slot of `ctx`.
	static T* get(SSL* ctx) {
		return cast(T*)SSL_get_ex_data(ctx, index);
	}
}
/// Slot carrying the NUL-terminated hostname that the verify callback checks.
alias HostnameExData = ExData!("hostname", const(char));

/// Copies `arg` into a freshly malloc'ed, NUL-terminated C string.
///
/// The caller owns the result and must release it with core.stdc.stdlib.free.
/// Allocation failure is reported through onOutOfMemoryError instead of
/// writing through a null pointer.
private const(char)* ccopy(const(char)[] arg) nothrow {
	import core.exception : onOutOfMemoryError;
	import core.stdc.stdlib : malloc;

	auto raw = cast(char*)malloc(arg.length + 1);
	if (raw is null) {
		onOutOfMemoryError();
	}

	char[] r = raw[0..arg.length + 1];
	r[0..$-1] = arg;
	r[$-1] = 0;
	return r.ptr;
}

/// Returns the DNS entries of the certificate's subjectAltName extension.
///
/// Returns an empty array when the extension is absent.
private string[] alternativeNames(X509* x509) {
	import deimos.openssl.objects : NID_subject_alt_name;
	import deimos.openssl.x509v3 : GENERAL_NAME, GENERAL_NAMES, GENERAL_NAMES_free;
	import std.algorithm : filter, map;
	import std.array : array;
	import std.range : iota;

	auto names = cast(GENERAL_NAMES*)X509_get_ext_d2i(x509, NID_subject_alt_name, null, null);
	if (names is null) {
		// No subjectAltName extension (or it could not be decoded).
		return null;
	}
	scope(exit) GENERAL_NAMES_free(names);

	// `.array` copies every name (idup) before the scope guard frees `names`.
	return 0.iota(sk_GENERAL_NAME_num(names))
		.map!(i => sk_GENERAL_NAME_value(names, i))
		.filter!(gen => gen.type == GENERAL_NAME.GEN_DNS)
		.map!(gen => (cast(const(char)*)gen.d.dNSName.data)[0..gen.d.dNSName.length].idup)
		.array;
}

/// Returns every commonName entry of the certificate subject, in the order
/// X509_NAME_get_index_by_NID reports them.
private string[] commonNames(X509* x509) {
	import std.algorithm : map;
	import std.array : array;

	auto name = X509_get_subject_name(x509);

	// Input range over the entry indexes holding a commonName; a negative
	// index from X509_NAME_get_index_by_NID ends the range.
	struct Indexes {
		int front = -1;
		void popFront() {
			front = X509_NAME_get_index_by_NID(name, NID_commonName, front);
		}
		bool empty() {
			return front < 0;
		}
	}

	Indexes index;
	index.popFront();

	return index
		.map!(i => X509_NAME_get_entry(name, i))
		.map!(entry => X509_NAME_ENTRY_get_data(entry))
		.map!(common_name => (cast(const(char)*)common_name.data)[0..common_name.length].idup)
		.array;
}

/// Tells whether certificate name `patternStr` covers `hostnameStr`.
///
/// Both names are split on '.' and compared label by label:
///  - the label counts must be equal, so a longer certificate name that
///    merely starts with the hostname does not match;
///  - only the leftmost pattern label may be the wildcard "*", and it stands
///    for exactly one hostname label (RFC 6125, section 6.4.3);
///  - labels are compared case-insensitively, as DNS names are;
///  - an empty hostname never matches.
private bool matches(const(char)[] patternStr, const(char)[] hostnameStr) {
	import std.string : split;
	import std.uni : icmp;

	auto pattern = patternStr.split(".");
	auto hostname = hostnameStr.split(".");

	if (hostname.length == 0 || pattern.length != hostname.length) {
		return false;
	}

	foreach (i; 0..hostname.length) {
		if (i == 0 && pattern[i] == "*") {
			continue;
		}

		if (icmp(pattern[i], hostname[i]) != 0) {
			return false;
		}
	}

	return true;
}

/// One OpenSSL BIO pair: bytes written to `write_` become readable from
/// `read_`. Bytes the pair did not accept are kept in `pending`.
struct BioPair {
	BIO* read_;
	BIO* write_;
	ubyte[] pending;

	@disable this(this);

	/// Creates the pair. The int argument is ignored; it only exists so the
	/// struct can have an explicit constructor.
	this(int) {
		BIO_new_bio_pair(&read_, 0, &write_, 0);
		enforce(read_ !is null);
		enforce(write_ !is null);
	}

	~this() nothrow {
		if (!read_) return;
		BIO_free(read_);
		BIO_free(write_);
		read_ = null;
		write_ = null;
	}

	/// Queues `data` for the reading side. Whatever the BIO does not accept
	/// right away is appended to `pending`; earlier pending bytes always go
	/// first so ordering is preserved.
	void write(const(ubyte)[] data) {
		if (pending.length > 0) {
			writeSome(pending);
		}

		// Only touch the new data once everything older has been accepted.
		if (pending.length == 0) {
			writeSome(data);
		}

		// `data` now holds just the part the BIO did not take (possibly empty).
		pending ~= data;
	}

	/// Retries writing the pending bytes.
	/// Returns: false when there was nothing pending, true otherwise.
	bool flushPending() {
		if (pending.length == 0) return false;

		writeSome(pending);
		return true;
	}

	/// Writes as much of `data` as the BIO accepts and advances `data` past
	/// the accepted bytes.
	///
	/// A negative BIO_write result with the retry flag set means the pair's
	/// buffer is full; that is treated as "zero bytes written" so the caller
	/// can keep the bytes in `pending`.
	/// Throws: OpensslError on any other negative result.
	private void writeSome(ref inout(ubyte)[] data) {
		import deimos.openssl.bio : BIO_should_retry;

		int r = BIO_write(write_, data.ptr, cast(int)data.length);
		if (r < 0) {
			if (BIO_should_retry(write_)) {
				return;
			}
			throw new OpensslError(r, 0);
		}
		data = data[r..$];
	}

	/// Reads up to 1024 bytes from the reading side.
	/// Returns: the bytes read, or null when BIO_read reports a negative result.
	immutable(const(ubyte)[]) read() {
		import std.exception : assumeUnique;
		ubyte[] data = new ubyte[1024];
		int r = BIO_read(read_, data.ptr, cast(int)data.length);
		if (r < 0) {
			return null;
		}
		data.length = r;
		return assumeUnique(data);
	}
}

/// Exception carrying an OpenSSL return value, an SSL_get_error-style code
/// and, when available, the first message from OpenSSL's error queue.
class OpensslError : Exception {
	this(int ret, int err, string file = __FILE__, size_t line = __LINE__) {
		import deimos.openssl.err : ERR_error_string, ERR_get_error;
		import std.string : fromStringz;

		string msg;
		auto errNum = ERR_get_error();
		if (errNum != 0) {
			msg = ERR_error_string(errNum, null).fromStringz.idup;
		}
		super("OpensslError ret=%s err=%s, msg %s".format(ret, err, msg), file, line);
	}
}

/// Thrown when the underlying stream reports EOF while TLS still needs data.
class UnderlyingShutdown : Exception {
	this() {
		super("Underlying connection shutdown when expecting data");
	}
}

/// Owns an SSL_CTX, configured either for clients or for servers.
class TlsContext {
private:
	SSL_CTX* ctx;

public:
	/// Creates a client context.
	this() {
		ctx = SSL_CTX_new(SSLv23_client_method());
		enforce(ctx !is null);
	}

	/// Creates a server context that uses the given certificate chain file
	/// and PEM private key file.
	/// Throws: OpensslError when either file cannot be loaded.
	this(string serverChainPath, string serverKeyPath) {
		import std.string : toStringz;

		ctx = SSL_CTX_new(SSLv23_server_method());
		enforce(ctx !is null);
		int rc = SSL_CTX_use_certificate_chain_file(ctx, serverChainPath.toStringz);
		if (rc <= 0) throw new OpensslError(rc, 0);
		rc = SSL_CTX_use_PrivateKey_file(ctx, serverKeyPath.toStringz, SSL_FILETYPE_PEM);
		if (rc <= 0) throw new OpensslError(rc, 0);
	}

	~this() nothrow {
		SSL_CTX_free(ctx);
	}

	/// Loads the trusted CA certificates from `cafile`.
	/// Throws: OpensslError when the file cannot be loaded.
	void load_verify_locations(string cafile) {
		import std.string : toStringz;

		int rc = SSL_CTX_load_verify_locations(ctx, cafile.toStringz, null);
		if (rc <= 0) {
			throw new OpensslError(rc, 0);
		}
	}
}

/// TLS layer on top of another Stream.
///
/// OpenSSL never touches the underlying stream directly: it reads from
/// `tlsRead` and writes to `tlsWrite`, and `operate` shuttles bytes between
/// those BIO pairs and `underlying`.
///
/// NOTE(review): no SSL_free call is visible in this class, so `ssl` appears
/// never to be released, while the BIOs handed to SSL_set_bio are freed by
/// the BioPair destructors. Confirm the intended ownership before adding one.
class TlsStream : Stream {
private:
	Stream underlying;
	TlsContext ctx;
	SSL* ssl;
	BioPair tlsWrite;
	BioPair tlsRead;
	ubyte[] readBuffer;

	// Outcome of one OpenSSL call: a non-negative value is the call's own
	// result, the negative members say which direction must make progress.
	enum Want : int {
		Success = 0,
		Read = -1,
		Write = -2
	}

	/// Runs `a(ssl, args)` once.
	/// Returns: Want.Read / Want.Write when OpenSSL needs more I/O, otherwise
	/// the non-negative return value of `a` cast to Want.
	/// Throws: OpensslError on any other negative result.
	Want tryOperate(alias a, Args...)(Args args) {
		import deimos.openssl.err : ERR_clear_error;

		ERR_clear_error();
		int ret = a(ssl, args);
		if (ret < 0) {
			int err = SSL_get_error(ssl, ret);
			if (err == SSL_ERROR_WANT_READ) {
				return Want.Read;
			} else if (err == SSL_ERROR_WANT_WRITE) {
				return Want.Write;
			} else {
				throw new OpensslError(ret, err);
			}
		}
		return cast(Want)ret;
	}

	/// Repeats `a(ssl, args)` until it yields a non-negative result, doing
	/// the I/O OpenSSL asks for in between.
	/// Returns: a promise for the final non-negative return value of `a`.
	Promise!int operate(alias a, Args...)(Args args) nothrow {
		import upromised.promise : break_, continue_, do_while;

		int r;
		return do_while(() {
			auto want = tryOperate!a(args);
			// First push everything OpenSSL produced to the underlying stream.
			return do_while(() {
				auto toWrite = tlsWrite.read();
				if (toWrite.length > 0) {
					return underlying.write(toWrite)
						.then(() => continue_);
				} else {
					return break_;
				}
			}).then(() {
				if (want == Want.Read) {
					// Feed leftover bytes before asking the stream for more.
					if (tlsRead.flushPending()) {
						return continue_;
					} else {
						return underlying.read().next().then((data) {
							if (data.eof) {
								throw new UnderlyingShutdown;
							}
							tlsRead.write(data.value);
							return continue_;
						});
					}
				} else if (want >= Want.Success) {
					r = want;
					return break_;
				} else {
					// Want.Write: the flush above already ran, so just retry.
					return continue_;
				}
			});
		}).then(() => r);
	}

	/// Throws OpensslError unless `rc` is 1, the only value this class
	/// accepts as a completed handshake from SSL_accept / SSL_connect.
	void enforceHandshake(int rc) {
		if (rc != 1) {
			throw new OpensslError(rc, SSL_get_error(ssl, rc));
		}
	}

public:
	/// Wraps `stream` in a TLS session created from `ctx`.
	this(Stream stream, TlsContext ctx) {
		underlying = stream;
		this.ctx = ctx;
		ssl = SSL_new(ctx.ctx);
		enforce(ssl !is null);
		tlsWrite = BioPair(0);
		tlsRead = BioPair(0);
		// OpenSSL reads what we feed into tlsRead and writes into tlsWrite.
		SSL_set_bio(ssl, tlsRead.read_, tlsWrite.write_);
		readBuffer.length = 1024;
	}

	/// Performs the server side of the handshake.
	/// The promise fails with OpensslError when the handshake does not complete.
	Promise!void accept() nothrow {
		return operate!(SSL_accept).then((rc) { enforceHandshake(rc); });
	}

	/// Performs the client side of the handshake.
	///
	/// When `hostname` is given, the leaf certificate must carry a matching
	/// subjectAltName DNS entry or, failing that, a matching last commonName.
	/// When it is null, no custom verify callback is installed.
	/// The promise fails with OpensslError when the handshake does not complete.
	Promise!void connect(string hostname = null) nothrow {
		import core.stdc.stdlib : free;
		import std.algorithm : any, map;
		import std.string : fromStringz;

		extern(C) int function(int, X509_STORE_CTX*) verify;

		if (hostname !is null) {
			// The callback is a plain C function and cannot capture
			// `hostname`, so a C copy travels with the SSL object instead.
			HostnameExData.set(this.ssl, hostname.ccopy);
			verify = (preverified, ctx) {
				if (preverified == 0) {
					return 0;
				}

				// Only the leaf certificate (depth 0) is matched against the hostname.
				if (X509_STORE_CTX_get_error_depth(ctx) != 0) {
					return 1;
				}

				auto ssl = cast(SSL*) X509_STORE_CTX_get_ex_data(ctx, SSL_get_ex_data_X509_STORE_CTX_idx());
				const(char)[] hostname = HostnameExData.get(ssl).fromStringz;
				auto x509 = X509_STORE_CTX_get_current_cert(ctx);

				if (x509.alternativeNames.map!((an) => an.matches(hostname)).any) {
					return 1;
				}

				string[] commonNames = x509.commonNames;
				if (commonNames.length > 0 && commonNames[$-1].matches(hostname)) {
					return 1;
				}

				return 0;
			};
		}

		SSL_set_verify(this.ssl, SSL_VERIFY_PEER, verify);
		return operate!(SSL_connect).then((rc) { enforceHandshake(rc); }).finall(() {
			// Release the hostname copy whether or not the handshake succeeded.
			auto hostnameCopy = HostnameExData.get(this.ssl);
			if (hostnameCopy !is null) {
				free(cast(void*)hostnameCopy);
				HostnameExData.set(this.ssl, null);
			}
		});
	}

	/// Encrypts and sends `data`; the promise fails unless SSL_write reports
	/// that all of it was written.
	override Promise!void write(immutable(ubyte)[] data) nothrow {
		return operate!(SSL_write)(data.ptr, cast(int)data.length).then((a) { enforce(a == data.length);});
	}

	/// Runs SSL_shutdown, then shuts down the underlying stream.
	override Promise!void shutdown() nothrow {
		return operate!(SSL_shutdown,).then((a) => underlying.shutdown);
	}

	/// Returns an iterator of decrypted chunks; an empty chunk from readOne
	/// is reported as end of stream.
	override PromiseIterator!(const(ubyte)[]) read() nothrow {
		return new class PromiseIterator!(const(ubyte)[]) {
			override Promise!ItValue next(Promise!bool) {
				return readOne()
					.then((chunk) => chunk.length > 0 ? ItValue(false, chunk) : ItValue(true));
			}
		};
	}

	/// Closes the underlying stream.
	override Promise!void close() nothrow {
		return underlying.close();
	}
protected:
	/// Reads one decrypted chunk of at most `readBuffer.length` bytes.
	///
	/// An UnderlyingShutdown raised while reading is swallowed and yields an
	/// empty chunk; every other exception propagates.
	/// NOTE(review): the result is a slice of the reused `readBuffer`, so it
	/// looks like it is only valid until the next readOne - confirm against callers.
	Promise!(const(ubyte)[]) readOne() nothrow {
		const(ubyte)[] r;
		return operate!(SSL_read)(readBuffer.ptr, cast(int)readBuffer.length).then((int read) {
			r = readBuffer[0..read];
		}).except((Exception e) {
			if ((cast(UnderlyingShutdown)e) !is null) {
			} else {
				throw e;
			}
		}).then(() => r);
	}
}