xref: /trafficserver/iocore/net/SSLCertLookup.cc (revision 4cfd5a73)
1 /** @file
2 
3   SSL Context management
4 
5   @section license License
6 
7   Licensed to the Apache Software Foundation (ASF) under one
8   or more contributor license agreements.  See the NOTICE file
9   distributed with this work for additional information
10   regarding copyright ownership.  The ASF licenses this file
11   to you under the Apache License, Version 2.0 (the
12   "License"); you may not use this file except in compliance
13   with the License.  You may obtain a copy of the License at
14 
15       http://www.apache.org/licenses/LICENSE-2.0
16 
17   Unless required by applicable law or agreed to in writing, software
18   distributed under the License is distributed on an "AS IS" BASIS,
19   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20   See the License for the specific language governing permissions and
21   limitations under the License.
22  */
23 
#include "P_SSLCertLookup.h"

#include "tscore/ink_config.h"
#include "tscore/I_Layout.h"
#include "tscore/MatcherUtils.h"
#include "tscore/Regex.h"
#include "tscore/Trie.h"
#include "tscore/BufferWriter.h"
#include "tscore/bwf_std_format.h"
#include "tscore/TestBox.h"

#include "I_EventSystem.h"

#include "P_SSLUtils.h"
#include "P_SSLConfig.h"
#include "SSLSessionTicket.h"

#include <algorithm>
#include <cctype>
#include <cstring>
#include <mutex>
#include <unordered_map>
#include <utility>
#include <vector>
45 
46 struct SSLAddressLookupKey {
SSLAddressLookupKeySSLAddressLookupKey47   explicit SSLAddressLookupKey(const IpEndpoint &ip)
48   {
49     // For IP addresses, the cache key is the hex address with the port concatenated. This makes the
50     // lookup insensitive to address formatting and also allow the longest match semantic to produce
51     // different matches if there is a certificate on the port.
52 
53     ts::FixedBufferWriter w{key, sizeof(key)};
54     w.print("{}", ts::bwf::Hex_Dump(ip)); // dump as raw hex bytes, don't format as IP address.
55     if (in_port_t port = ip.host_order_port(); port) {
56       sep = static_cast<unsigned char>(w.size());
57       w.print(".{:x}", port);
58     }
59     w.write('\0'); // force C-string termination.
60   }
61 
62   const char *
getSSLAddressLookupKey63   get() const
64   {
65     return key;
66   }
67   void
splitSSLAddressLookupKey68   split()
69   {
70     key[sep] = '\0';
71   }
72   void
unsplitSSLAddressLookupKey73   unsplit()
74   {
75     key[sep] = '.';
76   }
77 
78 private:
79   char key[(TS_IP6_SIZE * 2) /* hex addr */ + 1 /* dot */ + 4 /* port */ + 1 /* nullptr */];
80   unsigned char sep = 0; // offset of address/port separator
81 };
82 
83 struct SSLContextStorage {
84 public:
85   SSLContextStorage();
86   ~SSLContextStorage();
87 
88   /// Add a cert context to storage
89   /// @return The @a host_store index or -1 on error.
90   int insert(const char *name, SSLCertContext const &cc);
91 
92   /// Add a cert context to storage.
93   /// @a idx must be a value returned by a previous call to insert.
94   /// This creates an alias, a different @a name referring to the same
95   /// cert context.
96   /// @return @a idx
97   int insert(const char *name, int idx);
98   SSLCertContext *lookup(const char *name);
99   void printWildDomains() const;
100   unsigned
countSSLContextStorage101   count() const
102   {
103     return this->ctx_store.size();
104   }
105   SSLCertContext *
getSSLContextStorage106   get(unsigned i)
107   {
108     return &this->ctx_store[i];
109   }
110 
111 private:
112   /** A struct that can be stored a @c Trie.
113       It contains the index of the real certificate and the
114       linkage required by @c Trie.
115   */
116   struct ContextRef {
ContextRefSSLContextStorage::ContextRef117     ContextRef() {}
ContextRefSSLContextStorage::ContextRef118     explicit ContextRef(int n) : idx(n) {}
119     void
PrintSSLContextStorage::ContextRef120     Print() const
121     {
122       Debug("ssl", "Item=%p SSL_CTX=#%d", this, idx);
123     }
124     int idx = -1;           ///< Index in the context store.
125     LINK(ContextRef, link); ///< Require by @c Trie
126   };
127 
128   /// We can only match one layer with the wildcards
129   /// This table stores the wildcarded subdomain
130   std::unordered_map<std::string, int> wilddomains;
131   /// Contexts stored by IP address or FQDN
132   std::unordered_map<std::string, int> hostnames;
133   /// List for cleanup.
134   /// Exactly one pointer to each SSL context is stored here.
135   std::vector<SSLCertContext> ctx_store;
136 
137   /// Add a context to the clean up list.
138   /// @return The index of the added context.
139   int store(SSLCertContext const &cc);
140 };
141 
142 namespace
143 {
144 /** Copy @a src to @a dst, transforming to lower case.
145  *
146  * @param src Input string.
147  * @param dst Output buffer.
148  */
149 inline void
transform_lower(std::string_view src,ts::MemSpan<char> dst)150 transform_lower(std::string_view src, ts::MemSpan<char> dst)
151 {
152   if (src.size() > dst.size() - 1) { // clip @a src, reserving space for the terminal nul.
153     src = std::string_view{src.data(), dst.size() - 1};
154   }
155   auto final = std::transform(src.begin(), src.end(), dst.data(), [](char c) -> char { return std::tolower(c); });
156   *final++   = '\0';
157 }
158 } // namespace
159 
160 // Zero out and free the heap space allocated for ticket keys to avoid leaking secrets.
161 // The first several bytes stores the number of keys and the rest stores the ticket keys.
162 void
ticket_block_free(void * ptr)163 ticket_block_free(void *ptr)
164 {
165   if (ptr) {
166     ssl_ticket_key_block *key_block_ptr = static_cast<ssl_ticket_key_block *>(ptr);
167     unsigned num_ticket_keys            = key_block_ptr->num_keys;
168     memset(ptr, 0, sizeof(ssl_ticket_key_block) + num_ticket_keys * sizeof(ssl_ticket_key_t));
169   }
170   ats_free(ptr);
171 }
172 
173 ssl_ticket_key_block *
ticket_block_alloc(unsigned count)174 ticket_block_alloc(unsigned count)
175 {
176   ssl_ticket_key_block *ptr;
177   size_t nbytes = sizeof(ssl_ticket_key_block) + count * sizeof(ssl_ticket_key_t);
178 
179   ptr = static_cast<ssl_ticket_key_block *>(ats_malloc(nbytes));
180   memset(ptr, 0, nbytes);
181   ptr->num_keys = count;
182 
183   return ptr;
184 }
185 ssl_ticket_key_block *
ticket_block_create(char * ticket_key_data,int ticket_key_len)186 ticket_block_create(char *ticket_key_data, int ticket_key_len)
187 {
188   ssl_ticket_key_block *keyblock = nullptr;
189   unsigned num_ticket_keys       = ticket_key_len / sizeof(ssl_ticket_key_t);
190   if (num_ticket_keys == 0) {
191     Error("SSL session ticket key is too short (>= 48 bytes are required)");
192     goto fail;
193   }
194   Debug("ssl", "Create %d ticket key blocks", num_ticket_keys);
195 
196   keyblock = ticket_block_alloc(num_ticket_keys);
197 
198   // Slurp all the keys in the ticket key file. We will encrypt with the first key, and decrypt
199   // with any key (for rotation purposes).
200   for (unsigned i = 0; i < num_ticket_keys; ++i) {
201     const char *data = (const char *)ticket_key_data + (i * sizeof(ssl_ticket_key_t));
202 
203     memcpy(keyblock->keys[i].key_name, data, sizeof(keyblock->keys[i].key_name));
204     memcpy(keyblock->keys[i].hmac_secret, data + sizeof(keyblock->keys[i].key_name), sizeof(keyblock->keys[i].hmac_secret));
205     memcpy(keyblock->keys[i].aes_key, data + sizeof(keyblock->keys[i].key_name) + sizeof(keyblock->keys[i].hmac_secret),
206            sizeof(keyblock->keys[i].aes_key));
207   }
208 
209   return keyblock;
210 
211 fail:
212   ticket_block_free(keyblock);
213   return nullptr;
214 }
215 
216 ssl_ticket_key_block *
ssl_create_ticket_keyblock(const char * ticket_key_path)217 ssl_create_ticket_keyblock(const char *ticket_key_path)
218 {
219 #if TS_HAVE_OPENSSL_SESSION_TICKETS
220   ats_scoped_str ticket_key_data;
221   int ticket_key_len;
222   ssl_ticket_key_block *keyblock = nullptr;
223 
224   if (ticket_key_path != nullptr) {
225     ticket_key_data = readIntoBuffer(ticket_key_path, __func__, &ticket_key_len);
226     if (!ticket_key_data) {
227       Error("failed to read SSL session ticket key from %s", (const char *)ticket_key_path);
228       goto fail;
229     }
230     keyblock = ticket_block_create(ticket_key_data, ticket_key_len);
231   } else {
232     // Generate a random ticket key
233     ssl_ticket_key_t key;
234     RAND_bytes(reinterpret_cast<unsigned char *>(&key), sizeof(key));
235     keyblock = ticket_block_create(reinterpret_cast<char *>(&key), sizeof(key));
236   }
237 
238   return keyblock;
239 
240 fail:
241   ticket_block_free(keyblock);
242   return nullptr;
243 
244 #else  /* !TS_HAVE_OPENSSL_SESSION_TICKETS */
245   (void)ticket_key_path;
246   return nullptr;
247 #endif /* TS_HAVE_OPENSSL_SESSION_TICKETS */
248 }
249 
SSLCertContext(SSLCertContext const & other)250 SSLCertContext::SSLCertContext(SSLCertContext const &other)
251 {
252   opt        = other.opt;
253   userconfig = other.userconfig;
254   keyblock   = other.keyblock;
255   std::lock_guard<std::mutex> lock(other.ctx_mutex);
256   ctx = other.ctx;
257 }
258 
259 SSLCertContext &
operator =(SSLCertContext const & other)260 SSLCertContext::operator=(SSLCertContext const &other)
261 {
262   if (&other != this) {
263     this->opt        = other.opt;
264     this->userconfig = other.userconfig;
265     this->keyblock   = other.keyblock;
266     std::lock_guard<std::mutex> lock(other.ctx_mutex);
267     this->ctx = other.ctx;
268   }
269   return *this;
270 }
271 
272 shared_SSL_CTX
getCtx()273 SSLCertContext::getCtx()
274 {
275   std::lock_guard<std::mutex> lock(ctx_mutex);
276   return ctx;
277 }
278 
279 void
setCtx(shared_SSL_CTX sc)280 SSLCertContext::setCtx(shared_SSL_CTX sc)
281 {
282   std::lock_guard<std::mutex> lock(ctx_mutex);
283   ctx = std::move(sc);
284 }
285 
SSLCertLookup()286 SSLCertLookup::SSLCertLookup() : ssl_storage(new SSLContextStorage()), ssl_default(nullptr), is_valid(true) {}
287 
~SSLCertLookup()288 SSLCertLookup::~SSLCertLookup()
289 {
290   delete this->ssl_storage;
291 }
292 
293 SSLCertContext *
find(const char * address) const294 SSLCertLookup::find(const char *address) const
295 {
296   return this->ssl_storage->lookup(address);
297 }
298 
299 SSLCertContext *
find(const IpEndpoint & address) const300 SSLCertLookup::find(const IpEndpoint &address) const
301 {
302   SSLCertContext *cc;
303   SSLAddressLookupKey key(address);
304 
305   // First try the full address.
306   if ((cc = this->ssl_storage->lookup(key.get()))) {
307     return cc;
308   }
309 
310   // If that failed, try the address without the port.
311   if (address.port()) {
312     key.split();
313     return this->ssl_storage->lookup(key.get());
314   }
315 
316   return nullptr;
317 }
318 
319 int
insert(const char * name,SSLCertContext const & cc)320 SSLCertLookup::insert(const char *name, SSLCertContext const &cc)
321 {
322   return this->ssl_storage->insert(name, cc);
323 }
324 
325 int
insert(const IpEndpoint & address,SSLCertContext const & cc)326 SSLCertLookup::insert(const IpEndpoint &address, SSLCertContext const &cc)
327 {
328   SSLAddressLookupKey key(address);
329   return this->ssl_storage->insert(key.get(), cc);
330 }
331 
332 unsigned
count() const333 SSLCertLookup::count() const
334 {
335   return ssl_storage->count();
336 }
337 
338 SSLCertContext *
get(unsigned i) const339 SSLCertLookup::get(unsigned i) const
340 {
341   return ssl_storage->get(i);
342 }
343 
SSLContextStorage()344 SSLContextStorage::SSLContextStorage() {}
345 
~SSLContextStorage()346 SSLContextStorage::~SSLContextStorage() {}
347 
348 int
store(SSLCertContext const & cc)349 SSLContextStorage::store(SSLCertContext const &cc)
350 {
351   this->ctx_store.push_back(cc);
352   return this->ctx_store.size() - 1;
353 }
354 
355 int
insert(const char * name,SSLCertContext const & cc)356 SSLContextStorage::insert(const char *name, SSLCertContext const &cc)
357 {
358   int idx = this->store(cc);
359   idx     = this->insert(name, idx);
360   if (idx < 0) {
361     this->ctx_store.pop_back();
362   }
363   return idx;
364 }
365 
366 int
insert(const char * name,int idx)367 SSLContextStorage::insert(const char *name, int idx)
368 {
369   ats_wildcard_matcher wildcard;
370   char lower_case_name[TS_MAX_HOST_NAME_LEN + 1];
371   transform_lower(name, lower_case_name);
372 
373   shared_SSL_CTX ctx = this->ctx_store[idx].getCtx();
374   if (wildcard.match(lower_case_name)) {
375     // Strip the wildcard and store the subdomain
376     const char *subdomain = index(lower_case_name, '*');
377     if (subdomain && subdomain[1] == '.') {
378       subdomain += 2; // Move beyond the '.'
379     } else {
380       subdomain = nullptr;
381     }
382     if (subdomain) {
383       if (auto it = this->wilddomains.find(subdomain); it != this->wilddomains.end()) {
384         Debug("ssl", "previously indexed '%s' with SSL_CTX #%d, cannot index it with SSL_CTX #%d now", lower_case_name, it->second,
385               idx);
386         idx = -1;
387       } else {
388         this->wilddomains.emplace(subdomain, idx);
389         Debug("ssl", "indexed '%s' with SSL_CTX %p [%d]", lower_case_name, ctx.get(), idx);
390       }
391     }
392   } else {
393     if (auto it = this->hostnames.find(lower_case_name); it != this->hostnames.end() && idx != it->second) {
394       Debug("ssl", "previously indexed '%s' with SSL_CTX %d, cannot index it with SSL_CTX #%d now", lower_case_name, it->second,
395             idx);
396       idx = -1;
397     } else {
398       this->hostnames.emplace(lower_case_name, idx);
399       Debug("ssl", "indexed '%s' with SSL_CTX %p [%d]", lower_case_name, ctx.get(), idx);
400     }
401   }
402   return idx;
403 }
404 
405 void
printWildDomains() const406 SSLContextStorage::printWildDomains() const
407 {
408   for (auto &&it : this->wilddomains) {
409     Debug("ssl", "Stored wilddomain %s", it.first.c_str());
410   }
411 }
412 
413 SSLCertContext *
lookup(const char * name)414 SSLContextStorage::lookup(const char *name)
415 {
416   // First look for an exact name match
417   if (auto it = this->hostnames.find(name); it != this->hostnames.end()) {
418     return &(this->ctx_store[it->second]);
419   }
420   // Try lower casing it
421   char lower_case_name[TS_MAX_HOST_NAME_LEN + 1];
422   transform_lower(name, lower_case_name);
423   if (auto it_lower = this->hostnames.find(lower_case_name); it_lower != this->hostnames.end()) {
424     return &(this->ctx_store[it_lower->second]);
425   }
426 
427   // Then strip off the top domain name and look for a wildcard domain match
428   const char *subdomain = index(lower_case_name, '.');
429   if (subdomain) {
430     ++subdomain; // Move beyond the '.'
431     if (auto it = this->wilddomains.find(subdomain); it != this->wilddomains.end()) {
432       return &(this->ctx_store[it->second]);
433     }
434   }
435   return nullptr;
436 }
437 
438 #if TS_HAS_TESTS
439 
440 static char *
reverse_dns_name(const char * hostname,char (& reversed)[TS_MAX_HOST_NAME_LEN+1])441 reverse_dns_name(const char *hostname, char (&reversed)[TS_MAX_HOST_NAME_LEN + 1])
442 {
443   char *ptr        = reversed + sizeof(reversed);
444   const char *part = hostname;
445 
446   *(--ptr) = '\0'; // NUL-terminate
447 
448   while (*part) {
449     ssize_t len    = strcspn(part, ".");
450     ssize_t remain = ptr - reversed;
451 
452     if (remain < (len + 1)) {
453       return nullptr;
454     }
455 
456     ptr -= len;
457     memcpy(ptr, part, len);
458 
459     // Skip to the next domain component. This will take us to either a '.' or a NUL.
460     // If it's a '.' we need to skip over it.
461     part += len;
462     if (*part == '.') {
463       ++part;
464       *(--ptr) = '.';
465     }
466   }
467   transform_lower(ptr, {ptr, strlen(ptr) + 1});
468 
469   return ptr;
470 }
471 
REGRESSION_TEST(SSLWildcardMatch)(RegressionTest *t, int /* atype ATS_UNUSED */, int *pstatus)
{
  // Only a leading "*." label qualifies as a wildcard; plain names, an embedded '*', a bare "*" and
  // the empty string must all be rejected.
  TestBox tb(t, pstatus);
  ats_wildcard_matcher matcher;

  tb = REGRESSION_TEST_PASSED;

  tb.check(!matcher.match("foo.com"), "foo.com is not a wildcard");
  tb.check(matcher.match("*.foo.com"), "*.foo.com is a wildcard");
  tb.check(!matcher.match("bar*.foo.com"), "bar*.foo.com not a wildcard");
  tb.check(!matcher.match("*"), "* is not a wildcard");
  tb.check(!matcher.match(""), "'' is not a wildcard");
}
485 
REGRESSION_TEST(SSLReverseHostname)(RegressionTest *t, int /* atype ATS_UNUSED */, int *pstatus)
{
  TestBox box(t, pstatus);

  char reversed[TS_MAX_HOST_NAME_LEN + 1];

  // Reverse @a name and compare with @a expected. A lambda replaces the former "_R" macro, whose name
  // (underscore + capital letter) is reserved to the implementation. A nullptr result from
  // reverse_dns_name() never matches instead of crashing inside strcmp().
  auto reverses_to = [&reversed](const char *name, const char *expected) -> bool {
    const char *result = reverse_dns_name(name, reversed);
    return result != nullptr && strcmp(result, expected) == 0;
  };

  box = REGRESSION_TEST_PASSED;

  box.check(reverses_to("foo.com", "com.foo"), "reversed foo.com");
  box.check(reverses_to("bar.foo.com", "com.foo.bar"), "reversed bar.foo.com");
  box.check(reverses_to("foo", "foo"), "reversed foo");
  box.check(!reverses_to("foo.Com", "Com.foo"), "mixed case reversed foo.com mismatch");
  box.check(reverses_to("foo.Com", "com.foo"), "mixed case reversed foo.com match");
}
504 
505 #endif // TS_HAS_TESTS
506