1 module upromised.security;
2 version(hasSecurity):
3 import std.exception : enforce;
4 import std.format : format;
5 import upromised.promise : Promise, PromiseIterator;
6 import upromised.stream : Stream;
7 
// Minimal bindings for Apple's SecureTransport API (Security framework).
// Only the declarations used by this module are bound.
extern (C) {
	/// Subset of SecureTransport result codes referenced by this module.
	/// The C functions below may return other values as well.
	enum OSStatus {
		noErr = 0,
		errSSLWouldBlock = -9803,
		errSSLClosedGraceful = -9805,
		errSSLClosedAbort = -9806
	}

	/// Which end of the connection the context plays.
	enum SSLProtocolSide {
		kSSLServerSide,
		kSSLClientSide
	}

	/// Transport flavour of the context.
	enum SSLConnectionType {
		kSSLStreamType,
		kSSLDatagramType
	}

	// The connection reference is handed back verbatim to the I/O callbacks;
	// this module stores the owning `TlsStream` (a class reference) there.
	alias SSLConnectionRef = TlsStream;

	/// I/O callbacks SecureTransport invokes to move ciphertext.
	alias SSLReadFunc = OSStatus function(SSLConnectionRef connection, void *data, size_t* dataLength);
	alias SSLWriteFunc = OSStatus function(SSLConnectionRef connection, const(void)* data, size_t* dataLength);

	/// Opaque SecureTransport session state.
	struct SSLContext;
	SSLContext* SSLCreateContext(void* allocator, SSLProtocolSide protocolSide, SSLConnectionType connectionType) nothrow;
	void CFRelease(SSLContext*) nothrow;
	OSStatus SSLSetIOFuncs(SSLContext* context, SSLReadFunc readFunc, SSLWriteFunc writeFunc) nothrow;
	OSStatus SSLSetConnection(SSLContext* context, SSLConnectionRef connection) nothrow;
	OSStatus SSLSetPeerDomainName(SSLContext* context, const(char) *peerName, size_t peerNameLen) nothrow;
	OSStatus SSLHandshake(SSLContext* context) nothrow;
	OSStatus SSLClose(SSLContext* context) nothrow;
	OSStatus SSLWrite(SSLContext* context, const(void)* data, size_t dataLength, size_t *processed) nothrow;
	// `data` is an output buffer that SSLRead fills with plaintext, so it is
	// mutable (the C prototype takes `void *`).
	OSStatus SSLRead(SSLContext* context, void* data, size_t dataLength, size_t *processed) nothrow;
}
40 
/// Exception raised when a SecureTransport call reports a failing `OSStatus`.
class OSStatusError : Exception {
	/// The status code that caused this exception.
	OSStatus status;

	/// Params:
	///   status = failing status code; it is also rendered into the message.
	///   file   = source file of the throw site (defaults to the caller's).
	///   line   = source line of the throw site (defaults to the caller's).
	this(OSStatus status, string file = __FILE__, size_t line = __LINE__) {
		super(format("OSStatus(%s)", status), file, line);
		this.status = status;
	}
}
49 
/// Resolves with the first element produced by `read`.
///
/// The callback records the element and returns `false`. Presumably that
/// tells `each` to stop after one element (TODO confirm against
/// PromiseIterator). If the iterator finishes without producing anything,
/// the promise resolves with `T.init`.
private Promise!T readOne(T)(PromiseIterator!T read) nothrow {
	T first;
	auto finished = read.each((item) {
		first = item;
		return false;
	});
	return finished.then((_) => first);
}
57 
/**
 * TLS client stream layered over another `Stream` using SecureTransport.
 *
 * SecureTransport performs I/O synchronously through the `tryRead` and
 * `tryWrite` callbacks, while `underlying` is promise based. The two are
 * bridged as follows:
 *
 * - A callback that cannot complete immediately returns `errSSLWouldBlock`.
 * - `tryOperate` then performs the needed asynchronous I/O, either awaiting
 *   `pendingWrite` or reading another ciphertext chunk into `readBuffer`.
 * - The SecureTransport call is then retried.
 */
class TlsStream : Stream {
private:
	// Close codes used by this class; values match SecureTransport's
	// errSSLClosedGraceful / errSSLClosedAbort.
	enum OSStatus sslClosedGraceful = cast(OSStatus)(-9805);
	enum OSStatus sslClosedAbort = cast(OSStatus)(-9806);

	// Plaintext buffer size allocated for each `readOne()` call.
	enum size_t readChunkSize = 1024;

	Stream underlying;
	SSLContext* context;
	// Ciphertext received from `underlying` but not yet consumed by SecureTransport.
	ubyte[] readBuffer;
	// Ciphertext write started by `tryWrite` that has not been awaited yet.
	Promise!void pendingWrite;
	// Set once `underlying.read()` ends; no further reads are attempted afterwards.
	bool underlyingEof;

	// SecureTransport read callback.
	//
	// A request is satisfied only when all `*dataLength` bytes are already
	// buffered. Otherwise it reports 0 bytes and returns `errSSLWouldBlock`,
	// so that `tryOperate` fetches more data. After the underlying stream has
	// ended, it returns a close status instead, because no more data can arrive.
	OSStatus tryRead(void* dataArg, size_t* dataLength) nothrow {
		ubyte[] data = (cast(ubyte*)dataArg)[0..*dataLength];

		if (readBuffer.length >= *dataLength) {
			size_t n = *dataLength;
			data[] = readBuffer[0..n];
			readBuffer = readBuffer[n..$];

			return OSStatus.noErr;
		}

		*dataLength = 0;
		if (underlyingEof) {
			// Without this branch a finished stream would make every retry
			// block again, looping forever. An empty buffer means the peer
			// hung up between records; leftover bytes mean a truncated record.
			return readBuffer.length == 0 ? sslClosedGraceful : sslClosedAbort;
		}
		return OSStatus.errSSLWouldBlock;
	}

	// SecureTransport write callback.
	//
	// Only one write is kept in flight. While `pendingWrite` is set, it
	// reports 0 bytes and returns `errSSLWouldBlock`. Otherwise it starts
	// writing a copy of the whole chunk to `underlying` and reports it as
	// fully accepted.
	OSStatus tryWrite(const(void)* dataArg, size_t* dataLength) nothrow {
		if (pendingWrite !is null) {
			*dataLength = 0;
			return OSStatus.errSSLWouldBlock;
		}

		const(ubyte)[] data = (cast(const(ubyte)*)dataArg)[0..*dataLength];
		// Copied because SecureTransport may reuse its buffer after we return.
		pendingWrite = underlying.write(data.idup);
		return OSStatus.noErr;
	}

	// Calls `f(args)` once, then performs whatever asynchronous I/O it needs.
	//
	// Params:
	//   operated = optional progress counter written by `f`. When it is
	//              non-zero, the call is treated as having made progress.
	//
	// Returns: the status of `f`, resolved only after the I/O below completes.
	// - Success or progress: any pending write is awaited first.
	// - `errSSLWouldBlock`: the pending write is awaited, or one more chunk is
	//   read from `underlying`. The caller is expected to retry.
	// - Any other status: resolved immediately.
	Promise!OSStatus tryOperate(alias f, Args...)(size_t* operated, Args args) nothrow {
		OSStatus r = f(args);
		if (r == OSStatus.noErr || (operated !is null && *operated > 0)) {
			return (pendingWrite is null ? Promise!void.resolved() : pendingWrite).then(() {
				pendingWrite = null;
				return r;
			});
		}

		if (r == OSStatus.errSSLWouldBlock) {
			if (pendingWrite !is null) {
				return pendingWrite.then(() {
					pendingWrite = null;
					return r;
				});
			}

			if (underlyingEof) {
				// Nothing left to read. Let the caller retry so that `tryRead`
				// reports the close status.
				return Promise!OSStatus.resolved(r);
			}

			return underlying.read().readOne().then((chunk) {
				if (chunk is null) {
					// `readOne` yields `T.init` (null) when the iterator ends.
					underlyingEof = true;
				} else {
					readBuffer ~= chunk;
				}
				return r;
			});
		}

		return Promise!OSStatus.resolved(r);
	}

	// Repeats `tryOperate!f` until the status is not `errSSLWouldBlock`.
	Promise!OSStatus operate(alias f, Args...)(Args args) {
		return tryOperate!f(null, args).then((r) {
			if (r == OSStatus.errSSLWouldBlock) {
				return operate!f(args);
			}

			return Promise!OSStatus.resolved(r);
		});
	}
public:
	/// Creates a client-side TLS context whose I/O goes through `underlying`.
	///
	/// Throws:
	///   The exception raised by `enforce` if the context cannot be created.
	///   `OSStatusError` if installing the I/O callbacks or the connection
	///   reference fails.
	this(Stream underlying) {
		this.underlying = underlying;
		context = SSLCreateContext(null, SSLProtocolSide.kSSLClientSide, SSLConnectionType.kSSLStreamType);
		enforce(context !is null);

		// Route SecureTransport's I/O back into this object. If either call
		// fails the context is unusable, so release it and report why.
		OSStatus status = SSLSetIOFuncs(context, (self, a1, a2) => self.tryRead(a1, a2), (self, a1, a2) => self.tryWrite(a1, a2));
		if (status == OSStatus.noErr) {
			status = SSLSetConnection(context, this);
		}
		if (status != OSStatus.noErr) {
			CFRelease(context);
			context = null;
			throw new OSStatusError(status);
		}
	}

	~this() {
		if (context !is null) {
			CFRelease(context);
		}
	}

	/// Sets the peer domain name to `hostname` and runs the TLS handshake.
	///
	/// `hostname` is passed through even when it is null (as a null pointer
	/// with length 0).
	///
	/// Returns: a promise rejected with `OSStatusError` if either step fails.
	Promise!void connect(string hostname = null) nothrow {
		return Promise!void.resolved().then(() {
			auto status = SSLSetPeerDomainName(context, hostname.ptr, hostname.length);
			if (status != OSStatus.noErr) {
				throw new OSStatusError(status);
			}
			return;
		}).then(() => operate!SSLHandshake(context)).then((status) {
			if (status != OSStatus.noErr) {
				throw new OSStatusError(status);
			}
		});
	}

	/// Closes the underlying stream directly, without calling `SSLClose`.
	/// See `shutdown` for the TLS-level close.
	override Promise!void close() nothrow {
		return underlying.close();
	}

	/// Runs `SSLClose` to completion. The underlying stream is left open.
	///
	/// Returns: a promise rejected with `OSStatusError` on failure.
	override Promise!void shutdown() nothrow {
		return operate!SSLClose(context).then((status) {
			if (status != OSStatus.noErr) {
				throw new OSStatusError(status);
			}
		});
	}

	/// Encrypts and sends `data`.
	///
	/// Repeats `SSLWrite` on the unprocessed remainder until every byte has
	/// been accepted.
	///
	/// Returns: a promise rejected with `OSStatusError` on a failing status.
	override Promise!void write(immutable(ubyte)[] data) nothrow {
		if (data.length == 0) {
			return Promise!void.resolved();
		}

		size_t processed;
		return tryOperate!SSLWrite(&processed, context, data.ptr, data.length, &processed).then((status) {
			if (status != OSStatus.noErr && status != OSStatus.errSSLWouldBlock) {
				throw new OSStatusError(status);
			}

			return write(data[processed..$]);
		});
	}

	/// Iterates over decrypted chunks.
	///
	/// Iteration ends when `readOne()` resolves with a falsy (null) chunk,
	/// which happens after a graceful close.
	override PromiseIterator!(const(ubyte)[]) read() nothrow {
		return new class PromiseIterator!(const(ubyte)[]) {
			override Promise!ItValue next(Promise!bool) {
				return readOne().then((chunk) => chunk ? ItValue(false, chunk) : ItValue(true));
			}
		};
	}
protected:
	/// Reads one decrypted chunk of at most `readChunkSize` bytes.
	///
	/// Returns: the chunk, or null if the connection was closed gracefully.
	/// Other failures are propagated as `OSStatusError`.
	Promise!(const(ubyte)[]) readOne() nothrow {
		const(ubyte)[] r;
		return readOne(new ubyte[readChunkSize])
		.then((chunk) nothrow {
			r = chunk;
		}).except((OSStatusError e) {
			if (e.status != sslClosedGraceful) {
				throw e;
			}
			// Graceful close: leave `r` null so `read()` ends the iteration.
		}).then(() => r);
	}

	/// Repeats `SSLRead` into `data` until at least one plaintext byte arrives.
	///
	/// Returns: the filled prefix of `data`. The promise is rejected with
	/// `OSStatusError` on any status other than success or `errSSLWouldBlock`.
	Promise!(const(ubyte)[]) readOne(ubyte[] data) nothrow {
		size_t processed;
		return tryOperate!SSLRead(&processed, context, data.ptr, data.length, &processed).then((status) {
			if (status != OSStatus.noErr && status != OSStatus.errSSLWouldBlock) {
				throw new OSStatusError(status);
			}

			if (processed > 0) {
				return Promise!(const(ubyte)[]).resolved(data[0..processed]);
			}

			return readOne(data);
		});
	}
}