module upromised.security;
version(hasSecurity):
import std.exception : enforce;
import std.format : format;
import upromised.promise : Promise, PromiseIterator;
import upromised.stream : Stream;

// Minimal bindings to the Secure Transport C API (Security.framework) and
// CoreFoundation's CFRelease; only what TlsStream uses is declared here.
extern (C) {
    /// Subset of Secure Transport result codes that this module inspects.
    enum OSStatus {
        noErr = 0,
        /// An I/O callback could not make progress without blocking; retry later.
        errSSLWouldBlock = -9803,
        /// Connection closed. TlsStream.readOne() treats it as end-of-stream, and
        /// TlsStream's read callback reports it once the underlying stream has ended.
        errSSLClosedGraceful = -9805
    }

    enum SSLProtocolSide {
        kSSLServerSide,
        kSSLClientSide
    }

    enum SSLConnectionType {
        kSSLStreamType,
        kSSLDatagramType
    }

    // The connection handle passed back to the I/O callbacks is the TlsStream itself.
    alias SSLConnectionRef = TlsStream;

    alias SSLReadFunc = OSStatus function(SSLConnectionRef connection, void *data, size_t* dataLength);
    alias SSLWriteFunc = OSStatus function(SSLConnectionRef connection, const(void)* data, size_t* dataLength);

    struct SSLContext;
    SSLContext* SSLCreateContext(void* allocator, SSLProtocolSide protocolSide, SSLConnectionType connectionType) nothrow;
    void CFRelease(SSLContext*) nothrow;
    OSStatus SSLSetIOFuncs(SSLContext* context, SSLReadFunc readFunc, SSLWriteFunc writeFunc) nothrow;
    OSStatus SSLSetConnection(SSLContext* context, SSLConnectionRef connection) nothrow;
    OSStatus SSLSetPeerDomainName(SSLContext* context, const(char) *peerName, size_t peerNameLen) nothrow;
    OSStatus SSLHandshake(SSLContext* context) nothrow;
    OSStatus SSLClose(SSLContext* context) nothrow;
    OSStatus SSLWrite(SSLContext* context, const(void)* data, size_t dataLength, size_t *processed) nothrow;
    // SSLRead fills `data`, so the buffer is mutable.
    OSStatus SSLRead(SSLContext* context, void* data, size_t dataLength, size_t *processed) nothrow;
}

/// Exception carrying a non-success OSStatus returned by a Secure Transport call.
class OSStatusError : Exception {
    this(OSStatus status, string file = __FILE__, size_t line = __LINE__) {
        this.status = status;
        super("OSStatus(%s)".format(status), file, line);
    }

    /// The status code that caused the error.
    OSStatus status;
}

/// Resolves with the first element produced by `read`, or with `T.init` when
/// the iterator ends without producing one.
private Promise!T readOne(T)(PromiseIterator!T read) nothrow {
    T r;
    return read.each((chunk) {
        r = chunk;
        // Returning false after the first element presumably stops the
        // iteration (PromiseIterator.each contract — confirm).
        return false;
    }).then((_) => r);
}

/**
 * Client-side TLS stream layered on top of another `Stream`.
 *
 * Secure Transport drives the protocol through synchronous I/O callbacks
 * (tryRead/tryWrite), while `underlying` is promise-based. The callbacks only
 * touch `readBuffer` and `pendingWrite`, and answer errSSLWouldBlock when they
 * cannot make progress. tryOperate then performs the missing asynchronous I/O
 * so the Secure Transport call can be retried.
 */
class TlsStream : Stream {
private:
    Stream underlying;
    SSLContext* context;
    // Bytes received from `underlying` that Secure Transport has not consumed yet.
    ubyte[] readBuffer;
    // Write to `underlying` started by tryWrite and not yet awaited; null when none.
    Promise!void pendingWrite;
    // Set once a read from `underlying` ended without producing a chunk.
    bool underlyingEof;

    // Secure Transport read callback. It copies exactly *dataLength bytes out of
    // readBuffer. When fewer bytes are buffered it copies nothing and reports
    // errSSLWouldBlock, so that tryOperate fetches more data. If the underlying
    // stream has already ended, it reports errSSLClosedGraceful instead.
    OSStatus tryRead(void* dataArg, size_t* dataLength) nothrow {
        ubyte[] data = (cast(ubyte*)dataArg)[0..*dataLength];

        if (readBuffer.length >= *dataLength) {
            size_t n = *dataLength;
            data[] = readBuffer[0..n];
            readBuffer = readBuffer[n..$];

            return OSStatus.noErr;
        }

        *dataLength = 0;
        // After EOF no more data can arrive. Asking for it again would make every
        // retry wait on another empty read, and the caller would loop forever.
        return underlyingEof ? OSStatus.errSSLClosedGraceful : OSStatus.errSSLWouldBlock;
    }

    // Secure Transport write callback. It starts writing the bytes to
    // `underlying` and reports them all as written. Only one write may be
    // outstanding; while one is pending it refuses with errSSLWouldBlock.
    OSStatus tryWrite(const(void)* dataArg, size_t* dataLength) nothrow {
        if (pendingWrite !is null) {
            *dataLength = 0;
            return OSStatus.errSSLWouldBlock;
        }

        const(ubyte)[] data = (cast(const(ubyte)*)dataArg)[0..*dataLength];
        // Copy: write() takes immutable data and completes after this callback returns.
        pendingWrite = underlying.write(data.idup);
        return OSStatus.noErr;
    }

    // Runs the Secure Transport call `f(args)` once, then resolves with its
    // status after the I/O it needs has been done:
    //  - success, or progress reported through `operated`: await any pending
    //    write, then resolve;
    //  - errSSLWouldBlock: await the pending write if there is one; otherwise
    //    read one more chunk from `underlying` into readBuffer (the caller retries);
    //  - any other status: resolve with it immediately.
    Promise!OSStatus tryOperate(alias f, Args...)(size_t* operated, Args args) nothrow {
        OSStatus r = f(args);
        if (r == OSStatus.noErr || (operated !is null && *operated > 0)) {
            return (pendingWrite is null ? Promise!void.resolved() : pendingWrite).then(() {
                pendingWrite = null;
                return r;
            });
        }

        if (r == OSStatus.errSSLWouldBlock) {
            if (pendingWrite !is null) {
                return pendingWrite.then(() {
                    pendingWrite = null;
                    return r;
                });
            }

            return underlying.read().readOne().then((chunk) {
                // readOne yields T.init (null) when the iterator ended without a
                // chunk. Record it so that tryRead stops waiting for more data.
                if (chunk is null) {
                    underlyingEof = true;
                }
                readBuffer ~= chunk;
                return r;
            });
        }

        return Promise!OSStatus.resolved(r);
    }

    // Repeats tryOperate!f until the status is something other than
    // errSSLWouldBlock, then resolves with that final status.
    Promise!OSStatus operate(alias f, Args...)(Args args) {
        return tryOperate!f(null, args).then((r) {
            if (r == OSStatus.errSSLWouldBlock) {
                return operate!f(args);
            }

            return Promise!OSStatus.resolved(r);
        });
    }
public:
    /// Creates a client-side stream-mode TLS context whose I/O goes through `underlying`.
    /// Throws: Exception if the context cannot be created; OSStatusError if
    /// installing the I/O callbacks or the connection handle fails.
    this(Stream underlying) {
        this.underlying = underlying;
        context = SSLCreateContext(null, SSLProtocolSide.kSSLClientSide, SSLConnectionType.kSSLStreamType);
        enforce(context !is null);
        // The callbacks receive `this` back as their connection handle (set below).
        auto status = SSLSetIOFuncs(context, (self, a1, a2) => self.tryRead(a1, a2), (self, a1, a2) => self.tryWrite(a1, a2));
        if (status != OSStatus.noErr) {
            throw new OSStatusError(status);
        }
        status = SSLSetConnection(context, this);
        if (status != OSStatus.noErr) {
            throw new OSStatusError(status);
        }
    }

    /// Releases the Secure Transport context, if one was created.
    ~this() {
        if (context !is null) {
            CFRelease(context);
        }
    }

    /// Sets the peer domain name, then runs the TLS handshake to completion.
    /// A null `hostname` is passed through as a null pointer with length 0.
    /// The returned promise fails with OSStatusError if either step reports an error.
    Promise!void connect(string hostname = null) nothrow {
        return Promise!void.resolved().then(() {
            auto status = SSLSetPeerDomainName(context, hostname.ptr, hostname.length);
            if (status != OSStatus.noErr) {
                throw new OSStatusError(status);
            }
            return;
        }).then(() => operate!SSLHandshake(context)).then((status) {
            if (status != OSStatus.noErr) {
                throw new OSStatusError(status);
            }
        });
    }

    /// Closes `underlying` directly; no SSLClose is issued here (see shutdown).
    override Promise!void close() nothrow {
        return underlying.close();
    }

    /// Runs SSLClose to completion. Fails with OSStatusError on error.
    /// Does not close `underlying`.
    override Promise!void shutdown() nothrow {
        return operate!SSLClose(context).then((status) {
            if (status != OSStatus.noErr) {
                throw new OSStatusError(status);
            }
        });
    }

    /// Encrypts and writes all of `data`. It calls itself recursively on the
    /// remainder until nothing is left. Fails with OSStatusError on any status
    /// other than noErr or errSSLWouldBlock.
    override Promise!void write(immutable(ubyte)[] data) nothrow {
        if (data.length == 0) {
            return Promise!void.resolved();
        }

        // SSLWrite sets `processed` synchronously inside tryOperate. The
        // continuation reads it later, which is safe because it is captured by the closure.
        size_t processed;
        return tryOperate!SSLWrite(&processed, context, data.ptr, data.length, &processed).then((status) {
            if (status != OSStatus.noErr && status != OSStatus.errSSLWouldBlock) {
                throw new OSStatusError(status);
            }

            // NOTE(review): if SSLWrite returned errSSLWouldBlock with nothing
            // processed, these same bytes are submitted again. Confirm that Secure
            // Transport had not already queued them internally; if it had, they
            // would be sent twice.
            return write(data[processed..$]);
        });
    }

    /// Iterator over decrypted chunks. Iteration finishes when readOne()
    /// resolves with a null chunk (a graceful TLS close).
    override PromiseIterator!(const(ubyte)[]) read() nothrow {
        return new class PromiseIterator!(const(ubyte)[]) {
            override Promise!ItValue next(Promise!bool) {
                return readOne().then((chunk) => chunk ? ItValue(false, chunk) : ItValue(true));
            }
        };
    }
protected:
    /// Reads one decrypted chunk of up to 1024 bytes. Resolves with null when
    /// the TLS connection was closed gracefully. Other OSStatusErrors propagate.
    Promise!(const(ubyte)[]) readOne() nothrow {
        const(ubyte)[] r;
        return readOne(new ubyte[1024])
            .then((chunk) nothrow {
                r = chunk;
            }).except((OSStatusError e) {
                // A graceful close is end-of-stream: fall through and resolve with the null `r`.
                if (e.status != OSStatus.errSSLClosedGraceful) {
                    throw e;
                }
            }).then(() => r);
    }

    /// Decrypts into `data`, retrying until at least one byte is produced.
    /// Resolves with the filled prefix of `data`. Fails with OSStatusError on
    /// any status other than noErr or errSSLWouldBlock.
    Promise!(const(ubyte)[]) readOne(ubyte[] data) nothrow {
        size_t processed;
        return tryOperate!SSLRead(&processed, context, data.ptr, data.length, &processed).then((status) {
            if (status != OSStatus.noErr && status != OSStatus.errSSLWouldBlock) {
                throw new OSStatusError(status);
            }

            if (processed > 0) {
                return Promise!(const(ubyte)[]).resolved(data[0..processed]);
            }

            return readOne(data);
        });
    }
}