1 module upromised.tls;
2 version (hasOpenssl):
3 import deimos.openssl.ssl;
4 import upromised.stream : Stream;
5 import upromised.promise : DelegatePromise, Promise, PromiseIterator;
6 import upromised : fatal;
7 import std.exception : enforce;
8 import std.format : format;
9 
/// Process-wide OpenSSL initialisation, run once before `main`.
/// Registers the SSL library and loads the libcrypto/libssl error strings
/// (the text later rendered into `OpensslError` messages).
shared static this() {
    static import deimos.openssl.err;

    // The int status of SSL_library_init is not inspected.
    cast(void) SSL_library_init();
    deimos.openssl.err.ERR_load_crypto_strings();
    SSL_load_error_strings();
}
17 
/// Typed accessor for one `SSL` ex_data slot.
///
/// Every distinct instantiation (distinguished by `name`) reserves its own
/// slot index once, at program start-up. Only a raw pointer is stored; the
/// pointee's lifetime remains the caller's responsibility.
private struct ExData(alias name, T) {
    private shared static const(int) index;

    shared static this() {
        // NOTE(review): a negative index (allocation failure) is not checked here.
        index = SSL_get_ex_new_index(0, null, null, null, null);
    }

    /// Stores `data` in this slot of `ctx` (pass null to clear it).
    static void set(SSL* ctx, T* data) {
        immutable status = SSL_set_ex_data(ctx, index, cast(void*) data);
        assert(status == 1);
    }

    /// Returns the pointer previously stored in this slot of `ctx`.
    static T* get(SSL* ctx) {
        void* raw = SSL_get_ex_data(ctx, index);
        return cast(T*) raw;
    }
}

/// Slot holding the malloc'd, NUL-terminated hostname used during certificate verification.
alias HostnameExData = ExData!("hostname", const(char));
34 
/// Returns a `malloc`-allocated, NUL-terminated copy of `arg`.
///
/// The result lives outside the GC heap, so it can be handed to C code; the
/// caller must release it with `free`.
/// Throws: `OutOfMemoryError` (an `Error`, so the function stays `nothrow`)
/// if `malloc` fails, instead of writing through a null pointer.
private const(char)* ccopy(const(char)[] arg) nothrow {
    import core.exception : onOutOfMemoryError;
    import core.stdc.stdlib : malloc;

    auto buf = cast(char*) malloc(arg.length + 1);
    if (buf is null) onOutOfMemoryError();

    buf[0 .. arg.length] = arg[];
    buf[arg.length] = '\0';
    return buf;
}
43 
/// Collects the DNS entries of the certificate's subjectAltName extension.
///
/// Non-DNS general names are skipped. The extension structure returned by
/// `X509_get_ext_d2i` is freed before returning; the strings are GC copies.
private string[] alternativeNames(X509* x509) {
    import deimos.openssl.objects : NID_subject_alt_name;
    import deimos.openssl.x509v3 : GENERAL_NAME, GENERAL_NAMES, GENERAL_NAMES_free;

    auto sans = cast(GENERAL_NAMES*) X509_get_ext_d2i(x509, NID_subject_alt_name, null, null);
    scope(exit) GENERAL_NAMES_free(sans);

    string[] dnsNames;
    // When the extension is absent `sans` is null and the count is not positive, so the loop body never runs.
    immutable count = sk_GENERAL_NAME_num(sans);
    for (int idx = 0; idx < count; ++idx) {
        auto entry = sk_GENERAL_NAME_value(sans, idx);
        if (entry.type != GENERAL_NAME.GEN_DNS) continue;

        auto dns = entry.d.dNSName;
        // Slice by the ASN.1 length rather than relying on NUL termination.
        dnsNames ~= (cast(const(char)*) dns.data)[0 .. dns.length].idup;
    }
    return dnsNames;
}
59 
/// Returns the values of every commonName (CN) entry in the certificate's
/// subject, in the order OpenSSL reports their indexes.
private string[] commonNames(X509* x509) {
    auto subject = X509_get_subject_name(x509);

    string[] names;
    // Start the search at -1; each call returns the next matching index, and the loop stops at the first negative result.
    for (int pos = X509_NAME_get_index_by_NID(subject, NID_commonName, -1);
         pos >= 0;
         pos = X509_NAME_get_index_by_NID(subject, NID_commonName, pos)) {
        auto value = X509_NAME_ENTRY_get_data(X509_NAME_get_entry(subject, pos));
        names ~= (cast(const(char)*) value.data)[0 .. value.length].idup;
    }
    return names;
}
85 
/// Checks whether the certificate name `patternStr` covers `hostnameStr`.
///
/// Labels are compared one-to-one and case-insensitively. Both names must
/// have the same number of labels. A `*` is honoured only as the entire
/// left-most label (so `*.example.com` matches `a.example.com` but not
/// `example.com` or `a.b.example.com`); elsewhere it is compared literally.
/// An empty hostname never matches.
private bool matches(const(char)[] patternStr, const(char)[] hostnameStr) {
    import std.string : split;
    import std.uni : sicmp;

    if (hostnameStr.length == 0) return false;

    auto pattern = patternStr.split(".");
    auto hostname = hostnameStr.split(".");

    // Without this check a pattern such as "example.com.evil.net" would
    // match "example.com", because only the hostname's labels were walked.
    if (pattern.length != hostname.length) return false;

    foreach (i, label; hostname) {
        if (i == 0 && pattern[0] == "*") continue;
        if (sicmp(pattern[i], label) != 0) return false;
    }
    return true;
}
108 
/// An OpenSSL BIO pair plus a GC-side overflow buffer.
///
/// Bytes written to `write_` become readable from `read_`. `write` never
/// drops data: whatever does not fit into the pair's internal buffer is kept
/// in `pending` and pushed later by `flushPending`.
struct BioPair {
    BIO* read_;       /// half drained by `read()`
    BIO* write_;      /// half filled by `write()`
    ubyte[] pending;  /// bytes accepted by `write()` that have not reached the BIO yet

    @disable this(this);

    /// Creates the pair with default buffer sizes.
    /// Throws: `Exception` if OpenSSL could not allocate either half.
    this(int)  {
        BIO_new_bio_pair(&read_, 0, &write_, 0);
        enforce(read_ !is null);
        enforce(write_ !is null);
    }

    ~this() nothrow {
        if (!read_) return;
        BIO_free(read_);
        BIO_free(write_);
        read_ = null;
        write_ = null;
    }

    /// Queues `data` behind any already-pending bytes and pushes as much as fits.
    void write(const(ubyte)[] data) {
        // Whatever writeSome leaves in `data` (possibly all of it) is buffered.
        scope(success) pending ~= data;
        if (pending.length > 0) {
            writeSome(pending);
            // Keep ordering: new data must not overtake older pending bytes.
            if (pending.length > 0) return;
        }

        writeSome(data);
    }

    /// Tries to push buffered bytes into the BIO.
    /// Returns: `true` if there were pending bytes (whether or not any fit).
    bool flushPending() {
        if (pending.length == 0) return false;

        writeSome(pending);
        return true;
    }

    /// Writes a prefix of `data` to `write_` and advances `data` past it.
    /// A full pair buffer (retryable failure) leaves `data` untouched.
    /// Throws: `OpensslError` on a non-retryable write failure.
    private void writeSome(ref inout(ubyte)[] data) {
        import std.algorithm : min;

        if (data.length == 0) return;

        // BIO_write takes an int length; never let a huge slice wrap negative.
        int r = BIO_write(write_, data.ptr, cast(int) min(data.length, int.max));
        if (r < 0) {
            // A BIO pair signals "buffer full" as -1 with the retry flag set.
            // That is back-pressure, not an error: keep the bytes for later.
            if (BIO_should_retry(write_)) return;
            throw new OpensslError(r, 0);
        }
        data = data[r..$];
    }

    /// Reads up to 1024 bytes from `read_`.
    /// Returns: the bytes read, or `null` when `BIO_read` reports a negative result
    /// (for example, nothing available yet).
    immutable(const(ubyte)[]) read() {
        import std.exception : assumeUnique;
        ubyte[] data = new ubyte[1024];
        int r = BIO_read(read_, data.ptr, cast(int)data.length);
        if (r < 0) {
            return null;
        }
        data.length = r;
        return assumeUnique(data);
    }
}
163 
/// Exception carrying an OpenSSL call's return value, its `SSL_get_error`
/// code (0 when not applicable) and the text of the first queued OpenSSL error.
class OpensslError : Exception {
    /// Params:
    ///   ret = return value of the failing OpenSSL call
    ///   err = `SSL_get_error` result, or 0
    ///   file = source file of the throw site
    ///   line = source line of the throw site
    this(int ret, int err, string file = __FILE__, size_t line = __LINE__) {
        import deimos.openssl.err : ERR_error_string_n, ERR_get_error;
        import std.string : fromStringz;

        string detail;
        auto errNum = ERR_get_error();
        if (errNum != 0) {
            // ERR_error_string(e, null) formats into a shared static buffer,
            // which is not thread-safe; format into a local buffer instead.
            char[256] buf = 0;
            ERR_error_string_n(errNum, buf.ptr, buf.length);
            detail = buf.ptr.fromStringz.idup;
        }
        super("OpensslError ret=%s err=%s, msg %s".format(ret, err, detail), file, line);
    }
}
177 
/// Thrown when the underlying stream reports end-of-stream while the TLS layer
/// is still waiting for more ciphertext from it.
class UnderlyingShutdown : Exception {
    /// `file`/`line` default to the throw site, matching `OpensslError`.
    this(string file = __FILE__, size_t line = __LINE__) {
        super("Underlying connection shutdown when expecting data", file, line);
    }
}
183 
/// Owns an OpenSSL `SSL_CTX`; `TlsStream`s are created from it.
class TlsContext {
private:
    SSL_CTX* ctx;

public:
    /// Creates a client-side context (`SSLv23_client_method`).
    /// Throws: `Exception` if OpenSSL cannot allocate the context.
    this() {
        ctx = SSL_CTX_new(SSLv23_client_method());
        enforce(ctx !is null);
    }

    /// Creates a server-side context using a PEM certificate chain and a PEM private key.
    /// Throws: `OpensslError` if either file cannot be loaded.
    this(string serverChainPath, string serverKeyPath) {
        import std.string : toStringz;

        ctx = SSL_CTX_new(SSLv23_server_method());
        enforce(ctx !is null);
        // Do not leak the context when loading fails. Clearing the field also
        // keeps a later finalizer run from freeing it a second time.
        scope(failure) {
            SSL_CTX_free(ctx);
            ctx = null;
        }
        int rc = SSL_CTX_use_certificate_chain_file(ctx, serverChainPath.toStringz);
        if (rc <= 0) throw new OpensslError(rc, 0);
        rc = SSL_CTX_use_PrivateKey_file(ctx, serverKeyPath.toStringz, SSL_FILETYPE_PEM);
        if (rc <= 0) throw new OpensslError(rc, 0);
    }

    ~this() nothrow {
        SSL_CTX_free(ctx);
    }

    /// Adds the CA certificates in `cafile` to the trust store used for peer verification.
    /// Throws: `OpensslError` if the file cannot be loaded.
    void load_verify_locations(string cafile) {
        import std.string : toStringz;

        int rc = SSL_CTX_load_verify_locations(ctx, cafile.toStringz, null);
        if (rc <= 0) {
            throw new OpensslError(rc, 0);
        }
    }
}
218 
/// A TLS session layered over another `Stream`.
///
/// OpenSSL talks only to two in-memory BIO pairs. Ciphertext it produces is
/// drained from `tlsWrite` and written to `underlying`. Ciphertext read from
/// `underlying` is fed into `tlsRead`.
class TlsStream : Stream {
private:
    Stream underlying;
    TlsContext ctx;
    SSL* ssl;
    BioPair tlsWrite;
    BioPair tlsRead;
    ubyte[] readBuffer;

    /// Outcome of one non-blocking OpenSSL call: a non-negative value is the
    /// call's own result; negative values say which direction must make progress.
    enum Want : int {
        Success = 0,
        Read = -1,
        Write = -2
    }

    /// Calls `a(ssl, args)` once.
    /// Returns: the call's non-negative result, or `Want.Read`/`Want.Write`.
    /// Throws: `OpensslError` for any other failure.
    Want tryOperate(alias a, Args...)(Args args) {
        import deimos.openssl.err : ERR_clear_error;

        // SSL_get_error inspects the thread's error queue, so start clean.
        ERR_clear_error();
        int ret = a(ssl, args);
        if (ret < 0) {
            int err = SSL_get_error(ssl, ret);
            if (err == SSL_ERROR_WANT_READ) {
                return Want.Read;
            } else if (err == SSL_ERROR_WANT_WRITE) {
                return Want.Write;
            } else {
                throw new OpensslError(ret, err);
            }
        }
        return cast(Want)ret;
    }

    /// Repeats `a(ssl, args)` until it yields a non-negative result, moving
    /// ciphertext between the BIO pairs and `underlying` in between.
    /// Returns: a promise of that non-negative result.
    Promise!int operate(alias a, Args...)(Args args) nothrow {
        import upromised.promise : break_, continue_, do_while;
        
        int r;
        return do_while(() {
            auto want = tryOperate!a(args);
            // First send everything OpenSSL queued for the peer.
            return do_while(() {
                auto toWrite = tlsWrite.read();
                if (toWrite.length > 0) {
                    return underlying.write(toWrite)
                    .then(() => continue_);
                } else {
                    return break_;
                }
            }).then(() {
                if (want == Want.Read) {
                    // Prefer bytes already buffered before reading more from the peer.
                    if (tlsRead.flushPending()) {
                        return continue_;
                    } else {
                        return underlying.read().next().then((data) {
                            if (data.eof) {
                                throw new UnderlyingShutdown;
                            }
                            tlsRead.write(data.value);
                            return continue_;
                        });
                    }
                } else if (want >= Want.Success) {
                    r = want;
                    return break_;
                } else {
                    // Want.Write: pending output was flushed above; retry.
                    return continue_;
                }
            });
        }).then(() => r);
    }
    
public:
    /// Wraps `stream` in a new TLS session created from `ctx`.
    /// Throws: `Exception` if OpenSSL cannot allocate the session or its BIOs.
    this(Stream stream, TlsContext ctx) {
        underlying = stream;
        this.ctx = ctx;
        // The BIO pairs are created before the SSL object. That way, whenever
        // `ssl` is non-null it already owns tlsRead.read_ and tlsWrite.write_
        // (the destructor relies on this).
        tlsWrite = BioPair(0);
        tlsRead = BioPair(0);
        ssl = SSL_new(ctx.ctx);
        enforce(ssl !is null);
        SSL_set_bio(ssl, tlsRead.read_, tlsWrite.write_);
        readBuffer.length = 1024;
    }

    /// Releases the SSL session, which was previously leaked.
    ~this() nothrow {
        if (ssl is null) return;
        // SSL_free also frees the BIOs handed over by SSL_set_bio
        // (tlsRead.read_ and tlsWrite.write_). Free the two remaining halves
        // here, then clear every field so the BioPair destructors, which run
        // after this one, free nothing twice.
        SSL_free(ssl);
        ssl = null;
        BIO_free(tlsRead.write_);
        BIO_free(tlsWrite.read_);
        tlsRead.read_ = null;
        tlsRead.write_ = null;
        tlsWrite.read_ = null;
        tlsWrite.write_ = null;
    }

    /// Performs the server-side handshake.
    /// The promise fails with `OpensslError` unless `SSL_accept` reports success (1).
    Promise!void accept() nothrow {
        return operate!(SSL_accept).then((rc) {
            // 0 means the handshake was shut down without completing.
            if (rc != 1) throw new OpensslError(rc, 0);
        });
    }

    /// Performs the client-side handshake with peer verification enabled.
    ///
    /// When `hostname` is given, it is sent as SNI. The leaf certificate must
    /// also match it, through a subjectAltName DNS entry or else the last
    /// subject CN. The promise fails with `OpensslError` unless `SSL_connect`
    /// reports success (1).
    Promise!void connect(string hostname = null) nothrow {
        import core.stdc.stdlib : free;
        import std.algorithm : any, map;
        import std.string : fromStringz, toStringz;

        extern(C) int function(int, X509_STORE_CTX*) verify;

        if (hostname !is null) {
            // The C callback cannot capture locals, so the hostname travels via ex_data.
            HostnameExData.set(this.ssl, hostname.ccopy);
            SSL_set_tlsext_host_name(this.ssl, cast(char*)hostname.toStringz);
            verify = (preverified, ctx) {
                if (preverified == 0) {
                    return 0;
                }

                // Only the leaf certificate (depth 0) is checked against the hostname.
                if (X509_STORE_CTX_get_error_depth(ctx) != 0) {
                    return 1;
                }
                
                auto ssl = cast(SSL*) X509_STORE_CTX_get_ex_data(ctx, SSL_get_ex_data_X509_STORE_CTX_idx());
                const(char)[] hostname = HostnameExData.get(ssl).fromStringz;
                auto x509 = X509_STORE_CTX_get_current_cert(ctx);

                if (x509.alternativeNames.map!((an) => an.matches(hostname)).any) {
                    return 1;
                }

                string[] commonNames = x509.commonNames;
                if (commonNames.length > 0 && commonNames[$-1].matches(hostname)) {
                    return 1;
                }

                return 0;
            };
        }

        SSL_set_verify(this.ssl, SSL_VERIFY_PEER, verify);
        return operate!(SSL_connect).then((rc) {
            // 0 means the handshake was shut down without completing.
            if (rc != 1) throw new OpensslError(rc, 0);
        }).finall(() {
            // Release the malloc'd hostname copy whether the handshake succeeded or not.
            auto hostnameCopy = HostnameExData.get(this.ssl);
            if (hostnameCopy !is null) {
                free(cast(void*)hostnameCopy);
                HostnameExData.set(this.ssl, null);
            }
        });
    }

    /// Encrypts and sends `data`; fails unless `SSL_write` consumed all of it.
    override Promise!void write(immutable(ubyte)[] data) nothrow {
        return operate!(SSL_write)(data.ptr, cast(int)data.length).then((a) { enforce(a == data.length);});
    }

    /// Runs `SSL_shutdown`, then shuts down the underlying stream.
    override Promise!void shutdown() nothrow {
        return operate!(SSL_shutdown).then((a) => underlying.shutdown);
    }

    /// Iterates decrypted chunks; an empty chunk from `readOne` ends the iteration.
    override PromiseIterator!(const(ubyte)[]) read() nothrow {
        return new class PromiseIterator!(const(ubyte)[]) {
            override Promise!ItValue next(Promise!bool) {
                return readOne()
                .then((chunk) => chunk.length > 0 ? ItValue(false, chunk) : ItValue(true));
            }
        };
    }

    /// Closes the underlying stream directly (no TLS close_notify is sent here).
    override Promise!void close() nothrow {
        return underlying.close();
    }
protected:
    /// Reads one chunk of plaintext (at most `readBuffer.length` bytes).
    /// An `UnderlyingShutdown` is turned into an empty chunk; any other
    /// exception is rethrown.
    Promise!(const(ubyte)[]) readOne() nothrow {
        const(ubyte)[] r;
        return operate!(SSL_read)(readBuffer.ptr, cast(int)readBuffer.length).then((int read) {
            // NOTE(review): `r` aliases readBuffer, which the next readOne()
            // overwrites; callers that keep a chunk must copy it.
            r = readBuffer[0..read];
        }).except((Exception e) {
            if ((cast(UnderlyingShutdown)e) !is null) {
            } else {
                throw e;
            }
        }).then(() => r);
    }
}
386