connection.cpp

00001 /*
00002   Copyright (c) 2004-2006 by Jakob Schroeter <js@camaya.net>
00003   This file is part of the gloox library. http://camaya.net/gloox
00004 
00005   This software is distributed under a license. The full license
00006   agreement can be found in the file LICENSE in this distribution.
00007   This software may not be copied, modified, sold or distributed
00008   other than expressed in the named license agreement.
00009 
00010   This software is distributed without any warranty.
00011 */
00012 
00013 
00014 
00015 #include "gloox.h"
00016 
00017 #include "compression.h"
00018 #include "connection.h"
00019 #include "dns.h"
00020 #include "logsink.h"
00021 #include "prep.h"
00022 #include "parser.h"
00023 
00024 #ifdef __MINGW32__
00025 #include <winsock.h>
00026 #endif
00027 
00028 #ifndef WIN32
00029 #include <sys/types.h>
00030 #include <sys/socket.h>
00031 #include <sys/select.h>
00032 #include <unistd.h>
00033 #else
00034 #include <winsock.h>
00035 #endif
00036 
00037 #ifdef USE_WINTLS
00038 # include <schannel.h>
00039 #endif
00040 
00041 #include <time.h>
00042 
00043 #include <string>
00044 #include <sstream>
00045 
00046 namespace gloox
00047 {
00048 
00049   Connection::Connection( Parser *parser, const LogSink& logInstance, const std::string& server,
00050                           unsigned short port )
00051     : m_parser( parser ), m_state ( StateDisconnected ), m_disconnect ( ConnNoError ),
00052       m_logInstance( logInstance ), m_compression( 0 ), m_buf( 0 ),
00053       m_server( Prep::idna( server ) ), m_port( port ), m_socket( -1 ), m_bufsize( 17000 ),
00054       m_cancel( true ), m_secure( false ), m_fdRequested( false ), m_enableCompression( false )
00055   {
00056     m_buf = (char*)calloc( m_bufsize + 1, sizeof( char ) );
00057 #ifdef USE_OPENSSL
00058     m_ssl = 0;
00059 #endif
00060   }
00061 
00062   Connection::~Connection()
00063   {
00064     cleanup();
00065     free( m_buf );
00066     m_buf = 0;
00067     m_parser = 0;
00068   }
00069 
00070 #ifdef HAVE_TLS
00071   void Connection::setClientCert( const std::string& clientKey, const std::string& clientCerts )
00072   {
00073     m_clientKey = clientKey;
00074     m_clientCerts = clientCerts;
00075   }
00076 #endif
00077 
00078 #if defined( USE_OPENSSL )
00079   bool Connection::tlsHandshake()
00080   {
00081     SSL_library_init();
00082     SSL_CTX *sslCTX = SSL_CTX_new( TLSv1_client_method() );
00083     if( !sslCTX )
00084       return false;
00085 
00086     if( !SSL_CTX_set_cipher_list( sslCTX, "HIGH:MEDIUM:AES:@STRENGTH" ) )
00087       return false;
00088 
00089     StringList::const_iterator it = m_cacerts.begin();
00090     for( ; it != m_cacerts.end(); ++it )
00091       SSL_CTX_load_verify_locations( sslCTX, (*it).c_str(), NULL );
00092 
00093     if( !m_clientKey.empty() && !m_clientCerts.empty() )
00094     {
00095       SSL_CTX_use_certificate_chain_file( sslCTX, m_clientCerts.c_str() );
00096       SSL_CTX_use_PrivateKey_file( sslCTX, m_clientKey.c_str(), SSL_FILETYPE_PEM );
00097     }
00098 
00099     m_ssl = SSL_new( sslCTX );
00100     SSL_set_connect_state( m_ssl );
00101 
00102     BIO *socketBio = BIO_new_socket( m_socket, BIO_NOCLOSE );
00103     if( !socketBio )
00104       return false;
00105 
00106     SSL_set_bio( m_ssl, socketBio, socketBio );
00107     SSL_set_mode( m_ssl, SSL_MODE_AUTO_RETRY );
00108 
00109     if( !SSL_connect( m_ssl ) )
00110       return false;
00111 
00112     m_secure = true;
00113 
00114     int res = SSL_get_verify_result( m_ssl );
00115     if( res != X509_V_OK )
00116       m_certInfo.status = CertInvalid;
00117     else
00118       m_certInfo.status = CertOk;
00119 
00120     X509 *peer;
00121     peer = SSL_get_peer_certificate( m_ssl );
00122     if( peer )
00123     {
00124       char peer_CN[256];
00125       X509_NAME_get_text_by_NID( X509_get_issuer_name( peer ), NID_commonName, peer_CN, sizeof( peer_CN ) );
00126       m_certInfo.issuer = peer_CN;
00127       X509_NAME_get_text_by_NID( X509_get_subject_name( peer ), NID_commonName, peer_CN, sizeof( peer_CN ) );
00128       m_certInfo.server = peer_CN;
00129       std::string p;
00130       p.assign( peer_CN );
00131       int (*pf)( int ) = tolower;
00132       transform( p.begin(), p.end(), p.begin(), pf );
00133       if( p != m_server )
00134         m_certInfo.status |= CertWrongPeer;
00135     }
00136     else
00137     {
00138       m_certInfo.status = CertInvalid;
00139     }
00140 
00141     const char *tmp;
00142     tmp = SSL_get_cipher_name( m_ssl );
00143     if( tmp )
00144       m_certInfo.cipher = tmp;
00145 
00146     tmp = SSL_get_cipher_version( m_ssl );
00147     if( tmp )
00148       m_certInfo.protocol = tmp;
00149 
00150     return true;
00151   }
00152 
00153   inline bool Connection::tls_send( const void *data, size_t len )
00154   {
00155     int ret;
00156     ret = SSL_write( m_ssl, data, len );
00157     return true;
00158   }
00159 
00160   inline int Connection::tls_recv( void *data, size_t len )
00161   {
00162     return SSL_read( m_ssl, data, len );
00163   }
00164 
00165   inline bool Connection::tls_dataAvailable()
00166   {
00167     return false; // SSL_pending( m_ssl ); // FIXME: crashes
00168   }
00169 
00170   inline void Connection::tls_cleanup()
00171   {
00172     SSL_shutdown( m_ssl );
00173     SSL_free( m_ssl );
00174   }
00175 
00176 #elif defined( USE_GNUTLS )
00177   bool Connection::tlsHandshake()
00178   {
00179     const int protocolPriority[] = { GNUTLS_TLS1, GNUTLS_SSL3, 0 };
00180     const int kxPriority[]       = { GNUTLS_KX_RSA, 0 };
00181     const int cipherPriority[]   = { GNUTLS_CIPHER_AES_256_CBC, GNUTLS_CIPHER_AES_128_CBC,
00182                                              GNUTLS_CIPHER_3DES_CBC, GNUTLS_CIPHER_ARCFOUR, 0 };
00183     const int compPriority[]     = { GNUTLS_COMP_ZLIB, GNUTLS_COMP_NULL, 0 };
00184     const int macPriority[]      = { GNUTLS_MAC_SHA, GNUTLS_MAC_MD5, 0 };
00185 
00186     if( gnutls_global_init() != 0 )
00187       return false;
00188 
00189     if( gnutls_certificate_allocate_credentials( &m_credentials ) < 0 )
00190       return false;
00191 
00192     StringList::const_iterator it = m_cacerts.begin();
00193     for( ; it != m_cacerts.end(); ++it )
00194       gnutls_certificate_set_x509_trust_file( m_credentials, (*it).c_str(), GNUTLS_X509_FMT_PEM );
00195 
00196     if( !m_clientKey.empty() && !m_clientCerts.empty() )
00197     {
00198       gnutls_certificate_set_x509_key_file( m_credentials, m_clientKey.c_str(),
00199                                             m_clientCerts.c_str(), GNUTLS_X509_FMT_PEM );
00200     }
00201 
00202     if( gnutls_init( &m_session, GNUTLS_CLIENT ) != 0 )
00203     {
00204       gnutls_certificate_free_credentials( m_credentials );
00205       return false;
00206     }
00207 
00208     gnutls_protocol_set_priority( m_session, protocolPriority );
00209     gnutls_cipher_set_priority( m_session, cipherPriority );
00210     gnutls_compression_set_priority( m_session, compPriority );
00211     gnutls_kx_set_priority( m_session, kxPriority );
00212     gnutls_mac_set_priority( m_session, macPriority );
00213     gnutls_credentials_set( m_session, GNUTLS_CRD_CERTIFICATE, m_credentials );
00214 
00215     gnutls_transport_set_ptr( m_session, (gnutls_transport_ptr_t)m_socket );
00216     if( gnutls_handshake( m_session ) != 0 )
00217     {
00218       gnutls_deinit( m_session );
00219       gnutls_certificate_free_credentials( m_credentials );
00220       return false;
00221     }
00222     gnutls_certificate_free_ca_names( m_credentials );
00223 
00224     m_secure = true;
00225 
00226     unsigned int status;
00227     bool error = false;
00228 
00229     if( gnutls_certificate_verify_peers2( m_session, &status ) < 0 )
00230       error = true;
00231 
00232     m_certInfo.status = 0;
00233     if( status & GNUTLS_CERT_INVALID )
00234       m_certInfo.status |= CertInvalid;
00235     if( status & GNUTLS_CERT_SIGNER_NOT_FOUND )
00236       m_certInfo.status |= CertSignerUnknown;
00237     if( status & GNUTLS_CERT_REVOKED )
00238       m_certInfo.status |= CertRevoked;
00239     if( status & GNUTLS_CERT_SIGNER_NOT_CA )
00240       m_certInfo.status |= CertSignerNotCa;
00241     const gnutls_datum_t* certList = 0;
00242     unsigned int certListSize;
00243     if( !error && ( ( certList = gnutls_certificate_get_peers( m_session, &certListSize ) ) == 0 ) )
00244       error = true;
00245 
00246     gnutls_x509_crt_t *cert = new gnutls_x509_crt_t[certListSize+1];
00247     for( unsigned int i=0; !error && ( i<certListSize ); ++i )
00248     {
00249       if( !error && ( gnutls_x509_crt_init( &cert[i] ) < 0 ) )
00250         error = true;
00251       if( !error && ( gnutls_x509_crt_import( cert[i], &certList[i], GNUTLS_X509_FMT_DER ) < 0 ) )
00252         error = true;
00253     }
00254 
00255     if( ( gnutls_x509_crt_check_issuer( cert[certListSize-1], cert[certListSize-1] ) > 0 )
00256          && certListSize > 0 )
00257       certListSize--;
00258 
00259     bool chain = true;
00260     for( unsigned int i=1; !error && ( i<certListSize ); ++i )
00261     {
00262       chain = error = !verifyAgainst( cert[i-1], cert[i] );
00263     }
00264     if( !chain )
00265       m_certInfo.status |= CertInvalid;
00266     m_certInfo.chain = chain;
00267 
00268     m_certInfo.chain = verifyAgainstCAs( cert[certListSize], 0 /*CAList*/, 0 /*CAListSize*/ );
00269 
00270     int t = (int)gnutls_x509_crt_get_expiration_time( cert[0] );
00271     if( t == -1 )
00272       error = true;
00273     else if( t < time( 0 ) )
00274       m_certInfo.status |= CertExpired;
00275     m_certInfo.date_from = t;
00276 
00277     t = (int)gnutls_x509_crt_get_activation_time( cert[0] );
00278     if( t == -1 )
00279       error = true;
00280     else if( t > time( 0 ) )
00281       m_certInfo.status |= CertNotActive;
00282     m_certInfo.date_to = t;
00283 
00284     char name[64];
00285     size_t nameSize = sizeof( name );
00286     gnutls_x509_crt_get_issuer_dn( cert[0], name, &nameSize );
00287     m_certInfo.issuer = name;
00288 
00289     nameSize = sizeof( name );
00290     gnutls_x509_crt_get_dn( cert[0], name, &nameSize );
00291     m_certInfo.server = name;
00292 
00293     const char* info;
00294     info = gnutls_compression_get_name( gnutls_compression_get( m_session ) );
00295     if( info )
00296       m_certInfo.compression = info;
00297 
00298     info = gnutls_mac_get_name( gnutls_mac_get( m_session ) );
00299     if( info )
00300       m_certInfo.mac = info;
00301 
00302     info = gnutls_cipher_get_name( gnutls_cipher_get( m_session ) );
00303     if( info )
00304       m_certInfo.cipher = info;
00305 
00306     info = gnutls_protocol_get_name( gnutls_protocol_get_version( m_session ) );
00307     if( info )
00308       m_certInfo.protocol = info;
00309 
00310     if( !gnutls_x509_crt_check_hostname( cert[0], m_server.c_str() ) )
00311       m_certInfo.status |= CertWrongPeer;
00312 
00313     for( unsigned int i=0; i<certListSize; ++i )
00314       gnutls_x509_crt_deinit( cert[i] );
00315 
00316     delete[] cert;
00317 
00318     return true;
00319   }
00320 
00321   bool Connection::verifyAgainst( gnutls_x509_crt_t cert, gnutls_x509_crt_t issuer )
00322   {
00323     unsigned int result;
00324     gnutls_x509_crt_verify( cert, &issuer, 1, 0, &result );
00325     if( result & GNUTLS_CERT_INVALID )
00326       return false;
00327 
00328     if( gnutls_x509_crt_get_expiration_time( cert ) < time( 0 ) )
00329       return false;
00330 
00331     if( gnutls_x509_crt_get_activation_time( cert ) > time( 0 ) )
00332       return false;
00333 
00334     return true;
00335   }
00336 
00337   bool Connection::verifyAgainstCAs( gnutls_x509_crt_t cert, gnutls_x509_crt_t *CAList, int CAListSize )
00338   {
00339     unsigned int result;
00340     gnutls_x509_crt_verify( cert, CAList, CAListSize, GNUTLS_VERIFY_ALLOW_X509_V1_CA_CRT, &result );
00341     if( result & GNUTLS_CERT_INVALID )
00342       return false;
00343 
00344     if( gnutls_x509_crt_get_expiration_time( cert ) < time( 0 ) )
00345       return false;
00346 
00347     if( gnutls_x509_crt_get_activation_time( cert ) > time( 0 ) )
00348       return false;
00349 
00350     return true;
00351   }
00352 
00353   inline bool Connection::tls_send( const void *data, size_t len )
00354   {
00355     int ret;
00356     do
00357     {
00358       ret = gnutls_record_send( m_session, data, len );
00359     }
00360     while( ( ret == GNUTLS_E_AGAIN ) || ( ret == GNUTLS_E_INTERRUPTED ) );
00361     return true;
00362   }
00363 
00364   inline int Connection::tls_recv( void *data, size_t len )
00365   {
00366     return gnutls_record_recv( m_session, data, len );
00367   }
00368 
00369   inline bool Connection::tls_dataAvailable()
00370   {
00371     return false; // gnutls_check_pending( m_session ); // FIXME: crashes
00372   }
00373 
00374   inline void Connection::tls_cleanup()
00375   {
00376     gnutls_bye( m_session, GNUTLS_SHUT_RDWR );
00377     gnutls_deinit( m_session );
00378     gnutls_certificate_free_credentials( m_credentials );
00379     gnutls_global_deinit();
00380   }
00381 
00382 #elif defined( USE_WINTLS )
00383   bool Connection::tlsHandshake()
00384   {
00385     INIT_SECURITY_INTERFACE pInitSecurityInterface;
00386 
00387     m_lib = LoadLibrary( "secur32.dll" );
00388     if( m_lib == NULL )
00389       return false;
00390 
00391     pInitSecurityInterface = (INIT_SECURITY_INTERFACE)GetProcAddress( m_lib, "InitSecurityInterfaceA" );
00392     if( pInitSecurityInterface == NULL )
00393     {
00394       FreeLibrary( m_lib );
00395       m_lib = 0;
00396       return false;
00397     }
00398 
00399     m_securityFunc = pInitSecurityInterface();
00400     if( !m_securityFunc )
00401     {
00402       FreeLibrary( m_lib );
00403       m_lib = 0;
00404       return false;
00405     }
00406 
00407     SCHANNEL_CRED schannelCred;
00408     memset( &schannelCred, 0, sizeof( schannelCred ) );
00409     memset( &m_credentials, 0, sizeof( m_credentials ) );
00410     memset( &m_context, 0, sizeof( m_context ) );
00411 
00412     schannelCred.dwVersion = SCHANNEL_CRED_VERSION;
00413     schannelCred.grbitEnabledProtocols = SP_PROT_TLS1_CLIENT;
00414     schannelCred.cSupportedAlgs = 0; // FIXME
00415 #ifdef MSVC
00416     schannelCred.dwMinimumCipherStrength = 0; // FIXME
00417     schannelCred.dwMaximumCipherStrength = 0; // FIXME
00418 #else
00419     schannelCred.dwMinimumCypherStrength = 0; // FIXME
00420     schannelCred.dwMaximumCypherStrength = 0; // FIXME
00421 #endif
00422     schannelCred.dwSessionLifespan = 0;
00423     schannelCred.dwFlags = SCH_CRED_NO_SERVERNAME_CHECK | SCH_CRED_NO_DEFAULT_CREDS |
00424                            SCH_CRED_MANUAL_CRED_VALIDATION; // FIXME check
00425 
00426     TimeStamp timeStamp;
00427     SECURITY_STATUS ret;
00428     ret = m_securityFunc->AcquireCredentialsHandleA( NULL, UNISP_NAME_A, SECPKG_CRED_OUTBOUND,
00429                                      NULL, &schannelCred, NULL,
00430                                      NULL, &m_credentials, &timeStamp );
00431     if( ret != SEC_E_OK )
00432     {
00433       printf( "AcquireCredentialsHandleA failed\n" );
00434       return false;
00435     }
00436 
00437     m_sspiFlags = ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY | ISC_REQ_EXTENDED_ERROR
00438                       | ISC_REQ_MUTUAL_AUTH | ISC_REQ_REPLAY_DETECT | ISC_REQ_SEQUENCE_DETECT
00439                       | ISC_REQ_STREAM;
00440 
00441     SecBufferDesc outBufferDesc;
00442     SecBuffer outBuffers[1];
00443 
00444     outBuffers[0].BufferType = SECBUFFER_TOKEN;
00445     outBuffers[0].pvBuffer = NULL;
00446     outBuffers[0].cbBuffer = 0;
00447 
00448     outBufferDesc.ulVersion = SECBUFFER_VERSION;
00449     outBufferDesc.cBuffers = 1;
00450     outBufferDesc.pBuffers = outBuffers;
00451 
00452     long unsigned int sspiFlagsOut;
00453     ret = m_securityFunc->InitializeSecurityContextA( &m_credentials, NULL, NULL, m_sspiFlags, 0,
00454         SECURITY_NATIVE_DREP, NULL, 0, &m_context,
00455         &outBufferDesc, &sspiFlagsOut, &timeStamp );
00456     if( ret == SEC_I_CONTINUE_NEEDED && outBuffers[0].cbBuffer != 0 && outBuffers[0].pvBuffer != NULL )
00457     {
00458       printf( "OK: Continue needed: " );
00459 
00460       int ret = ::send( m_socket, (const char*)outBuffers[0].pvBuffer, outBuffers[0].cbBuffer, 0 );
00461       if( ret == SOCKET_ERROR || ret == 0 )
00462       {
00463         m_securityFunc->FreeContextBuffer( outBuffers[0].pvBuffer );
00464         m_securityFunc->DeleteSecurityContext( &m_context );
00465         return false;
00466       }
00467 
00468       m_securityFunc->FreeContextBuffer( outBuffers[0].pvBuffer );
00469       outBuffers[0].pvBuffer = NULL;
00470     }
00471 
00472     if( !handshakeLoop() )
00473     {
00474       printf( "handshakeLoop failed\n" );
00475       return false;
00476     }
00477 
00478     ret = m_securityFunc->QueryContextAttributes( &m_context, SECPKG_ATTR_STREAM_SIZES, &m_streamSizes );
00479     if( ret != SEC_E_OK )
00480     {
00481       printf( "could not read stream attribs (sizes)\n" );
00482       return false;
00483     }
00484 printf( "maximumMessage: %ld\n", m_streamSizes.cbMaximumMessage );
00485     int maxSize = m_streamSizes.cbHeader + m_streamSizes.cbMaximumMessage + m_streamSizes.cbTrailer;
00486     m_iBuffer = (char*)malloc( maxSize );
00487     if( !m_iBuffer )
00488       return false;
00489 
00490     m_oBuffer = (char*)malloc( maxSize );
00491     if( !m_oBuffer )
00492       return false;
00493 
00494     m_bufferOffset = 0;
00495     m_messageOffset = m_oBuffer + m_streamSizes.cbHeader;
00496 
00497     SecPkgContext_Authority streamAuthority;
00498     ret = m_securityFunc->QueryContextAttributes( &m_context, SECPKG_ATTR_AUTHORITY, &streamAuthority );
00499     if( ret != SEC_E_OK )
00500     {
00501       printf( "could not read stream attribs (sizes)\n" );
00502       return false;
00503     }
00504     else
00505     {
00506       m_certInfo.issuer.assign( streamAuthority.sAuthorityName );
00507     }
00508 
00509     SecPkgContext_ConnectionInfo streamInfo;
00510     ret = m_securityFunc->QueryContextAttributes( &m_context, SECPKG_ATTR_CONNECTION_INFO, &streamInfo );
00511     if( ret != SEC_E_OK )
00512     {
00513       printf( "could not read stream attribs (sizes)\n" );
00514       return false;
00515     }
00516     else
00517     {
00518       if( streamInfo.dwProtocol == SP_PROT_TLS1_CLIENT )
00519         m_certInfo.protocol = "TLS 1.0";
00520       else
00521         m_certInfo.protocol = "unknown";
00522 
00523       std::ostringstream oss;
00524       switch( streamInfo.aiCipher )
00525       {
00526         case CALG_3DES:
00527           oss << "3DES";
00528           break;
00529         case CALG_AES_128:
00530           oss << "AES";
00531           break;
00532         case CALG_AES_256:
00533           oss << "AES";
00534           break;
00535         case CALG_DES:
00536           oss << "DES";
00537           break;
00538         case CALG_RC2:
00539           oss << "RC2";
00540           break;
00541         case CALG_RC4:
00542           oss << "RC4";
00543           break;
00544         default:
00545           oss << "unknown";
00546       }
00547 
00548       oss << " " << streamInfo.dwCipherStrength;
00549       m_certInfo.cipher = oss.str();
00550       oss.str( "" );
00551 
00552       switch( streamInfo.aiHash  )
00553       {
00554         case CALG_MD5:
00555           oss << "MD5";
00556           break;
00557         case CALG_SHA:
00558           oss << "SHA";
00559           break;
00560         default:
00561           oss << "unknown";
00562       }
00563 
00564       oss << " " << streamInfo.dwHashStrength;
00565       m_certInfo.mac = oss.str();
00566 
00567       m_certInfo.compression = "unknown";
00568     }
00569 
00570     m_secure = true;
00571 
00572     return true;
00573   }
00574 
00575   bool Connection::handshakeLoop()
00576   {
00577     const int bufsize = 65536;
00578     char *buf = (char*)malloc( bufsize );
00579     if( !buf )
00580       return false;
00581 
00582     int bufFilled = 0;
00583     int dataRecv = 0;
00584     bool doRead = true;
00585 
00586     SecBufferDesc outBufferDesc, inBufferDesc;
00587     SecBuffer outBuffers[1], inBuffers[2];
00588 
00589     SECURITY_STATUS ret = SEC_I_CONTINUE_NEEDED;
00590 
00591     while( ret == SEC_I_CONTINUE_NEEDED ||
00592            ret == SEC_E_INCOMPLETE_MESSAGE ||
00593            ret == SEC_I_INCOMPLETE_CREDENTIALS )
00594     {
00595 
00596       if( doRead )
00597       {
00598         dataRecv = ::recv( m_socket, buf + bufFilled, bufsize - bufFilled, 0 );
00599 
00600         if( dataRecv == SOCKET_ERROR || dataRecv == 0 )
00601         {
00602           break;
00603         }
00604 
00605         printf( "%d bytes handshake data received\n", dataRecv );
00606 
00607         bufFilled += dataRecv;
00608       }
00609       else
00610       {
00611         doRead = true;
00612       }
00613 
00614       outBuffers[0].BufferType = SECBUFFER_TOKEN;
00615       outBuffers[0].pvBuffer = NULL;
00616       outBuffers[0].cbBuffer = 0;
00617 
00618       outBufferDesc.ulVersion = SECBUFFER_VERSION;
00619       outBufferDesc.cBuffers = 1;
00620       outBufferDesc.pBuffers = outBuffers;
00621 
00622       inBuffers[0].BufferType = SECBUFFER_TOKEN;
00623       inBuffers[0].pvBuffer = buf;
00624       inBuffers[0].cbBuffer = bufFilled;
00625 
00626       inBuffers[1].BufferType = SECBUFFER_EMPTY;
00627       inBuffers[1].pvBuffer = NULL;
00628       inBuffers[1].cbBuffer = 0;
00629 
00630       inBufferDesc.ulVersion = SECBUFFER_VERSION;
00631       inBufferDesc.cBuffers = 2;
00632       inBufferDesc.pBuffers = inBuffers;
00633 
00634       printf( "buffers inited, calling InitializeSecurityContextA\n" );
00635       long unsigned int sspiFlagsOut;
00636       TimeStamp timeStamp;
00637       ret = m_securityFunc->InitializeSecurityContextA( &m_credentials, &m_context, NULL,
00638                                                         m_sspiFlags, 0,
00639                                                         SECURITY_NATIVE_DREP, &inBufferDesc, 0, NULL,
00640                                                         &outBufferDesc, &sspiFlagsOut, &timeStamp );
00641       if( ret == SEC_E_OK || ret == SEC_I_CONTINUE_NEEDED ||
00642           ( FAILED( ret ) && sspiFlagsOut & ISC_RET_EXTENDED_ERROR ) )
00643       {
00644         if( outBuffers[0].cbBuffer != 0 && outBuffers[0].pvBuffer != NULL )
00645         {
00646           printf( "ISCA returned, buffers not empty\n" );
00647           dataRecv = ::send( m_socket, (const char*)outBuffers[0].pvBuffer, outBuffers[0].cbBuffer, 0  );
00648           if( dataRecv == SOCKET_ERROR || dataRecv == 0 )
00649           {
00650             m_securityFunc->FreeContextBuffer( &outBuffers[0].pvBuffer );
00651             m_securityFunc->DeleteSecurityContext( &m_context );
00652             free( buf );
00653             printf( "coudl not send bufer to server, exiting\n" );
00654             return false;
00655           }
00656 
00657           m_securityFunc->FreeContextBuffer( outBuffers[0].pvBuffer );
00658           outBuffers[0].pvBuffer = NULL;
00659         }
00660       }
00661 
00662       if( ret == SEC_E_INCOMPLETE_MESSAGE )
00663         continue;
00664 
00665       if( ret == SEC_E_OK )
00666       {
00667         printf( "handshake successful\n" );
00668         break;
00669       }
00670 
00671       if( FAILED( ret ) )
00672       {
00673         printf( "ISC failed: %ld\n", ret );
00674         break;
00675       }
00676 
00677       if( ret == SEC_I_INCOMPLETE_CREDENTIALS )
00678       {
00679         printf( "server requested client credentials\n" );
00680         ret = SEC_I_CONTINUE_NEEDED;
00681         continue;
00682       }
00683 
00684       if( inBuffers[1].BufferType == SECBUFFER_EXTRA )
00685       {
00686         printf("some xtra mem in inbuf\n" );
00687         MoveMemory( buf, buf + ( bufFilled - inBuffers[1].cbBuffer ),
00688                    inBuffers[1].cbBuffer );
00689 
00690         bufFilled = inBuffers[1].cbBuffer;
00691       }
00692       else
00693       {
00694         bufFilled = 0;
00695       }
00696     }
00697 
00698     if( FAILED( ret ) )
00699       m_securityFunc->DeleteSecurityContext( &m_context );
00700 
00701     free( buf );
00702 
00703     if( ret == SEC_E_OK )
00704       return true;
00705 
00706     return false;
00707   }
00708 
00709   inline bool Connection::tls_send( const void *data, size_t len )
00710   {
00711     if( len <= 0 )
00712       return true;
00713 
00714     SECURITY_STATUS ret;
00715 
00716     m_obuffers[0].BufferType = SECBUFFER_STREAM_HEADER;
00717     m_obuffers[0].pvBuffer = m_oBuffer;
00718     m_obuffers[0].cbBuffer = m_streamSizes.cbHeader;
00719 
00720     m_obuffers[1].BufferType = SECBUFFER_DATA;
00721     m_obuffers[1].pvBuffer = m_messageOffset;
00722 
00723     m_obuffers[2].BufferType = SECBUFFER_STREAM_TRAILER;
00724     m_obuffers[2].cbBuffer = m_streamSizes.cbTrailer;
00725 
00726     m_obuffers[3].BufferType = SECBUFFER_EMPTY;
00727     m_obuffers[3].pvBuffer = NULL;
00728     m_obuffers[3].cbBuffer = 0;
00729 
00730     m_omessage.ulVersion = SECBUFFER_VERSION;
00731     m_omessage.cBuffers = 4;
00732     m_omessage.pBuffers = m_obuffers;
00733 
00734     while( len > 0 )
00735     {
00736       if( m_streamSizes.cbMaximumMessage < len )
00737       {
00738         memcpy( m_messageOffset, data, m_streamSizes.cbMaximumMessage );
00739         len -= m_streamSizes.cbMaximumMessage;
00740         m_obuffers[1].cbBuffer = m_streamSizes.cbMaximumMessage;
00741         m_obuffers[2].pvBuffer = m_messageOffset + m_streamSizes.cbMaximumMessage;
00742       }
00743       else
00744       {
00745         memcpy( m_messageOffset, data, len );
00746         m_obuffers[1].cbBuffer = len;
00747         m_obuffers[2].pvBuffer = m_messageOffset + len;
00748         len = 0;
00749       }
00750 
00751       ret = m_securityFunc->EncryptMessage( &m_context, 0, &m_omessage, 0 );
00752       if( ret != SEC_E_OK )
00753       {
00754         printf( "encryptmessage failed %ld\n", ret );
00755         return false;
00756       }
00757 
00758       int t = ::send( m_socket, m_oBuffer,
00759                       m_obuffers[0].cbBuffer + m_obuffers[1].cbBuffer + m_obuffers[2].cbBuffer, 0 );
00760       if( t == SOCKET_ERROR || t == 0 )
00761       {
00762         printf( "could not send: %d\n", WSAGetLastError() );
00763         return false;
00764       }
00765     }
00766 
00767     return true;
00768   }
00769 
00770   inline int Connection::tls_recv( void *data, size_t len )
00771   {
00772     SECURITY_STATUS ret;
00773     SecBuffer *dataBuffer = 0;
00774     int readable = 0;
00775 
00776     int maxLength = m_streamSizes.cbHeader + m_streamSizes.cbMaximumMessage + m_streamSizes.cbTrailer;
00777 
00778     printf( "bufferOffset is %d\n", m_bufferOffset );
00779 
00780     int t = ::recv( m_socket, m_iBuffer + m_bufferOffset, maxLength - m_bufferOffset, 0 );
00781     if( t == SOCKET_ERROR )
00782     {
00783       printf( "got SocketError\n" );
00784       return 0;
00785     }
00786     else if( t == 0 )
00787     {
00788       printf( "got connection close\n" );
00789       return 0;
00790     }
00791     else
00792       m_bufferOffset += t;
00793 
00794     while( m_bufferOffset )
00795     {
00796       printf( "continuing with bufferOffset: %d\n", m_bufferOffset );
00797 
00798       m_ibuffers[0].pvBuffer = m_iBuffer;
00799       m_ibuffers[0].cbBuffer = m_bufferOffset;
00800       m_ibuffers[0].BufferType = SECBUFFER_DATA;
00801 
00802       m_ibuffers[1].BufferType = SECBUFFER_EMPTY;
00803       m_ibuffers[2].BufferType = SECBUFFER_EMPTY;
00804       m_ibuffers[3].BufferType = SECBUFFER_EMPTY;
00805 
00806       m_imessage.ulVersion = SECBUFFER_VERSION;
00807       m_imessage.cBuffers = 4;
00808       m_imessage.pBuffers = m_ibuffers;
00809 
00810       ret = m_securityFunc->DecryptMessage( &m_context, &m_imessage, 0, NULL );
00811 
00812       if( ret == SEC_E_INCOMPLETE_MESSAGE )
00813       {
00814         printf( "recv'ed incomplete message\n" );
00815         return readable;
00816       }
00817 
00818 
00819   //    if( ret == SEC_I_CONTEXT_EXPIRED )
00820   //      return 0;
00821 
00822       if( ret != SEC_E_OK && ret != SEC_I_RENEGOTIATE )
00823       {
00824         printf( "DecryptMessage returned %ld\n", ret );
00825         printf( "GetLastError(): %ld\n", GetLastError() );
00826         printf( "input buffer length: %d, read in this run: %d\n", m_bufferOffset, t );
00827         return false;
00828       }
00829 
00830       m_bufferOffset = 0;
00831 
00832       for( int i = 1; i < 4; ++i )
00833       {
00834         if( dataBuffer == 0 && m_ibuffers[i].BufferType == SECBUFFER_DATA )
00835         {
00836           dataBuffer = &m_ibuffers[i];
00837         }
00838         if( m_bufferOffset == 0 && m_ibuffers[i].BufferType == SECBUFFER_EXTRA )
00839         {
00840   //         m_extraBuffer = &m_ibuffers[i];
00841   printf( "git exetra buffer, size %ld\n", m_ibuffers[i].cbBuffer );
00842 //          memcpy( m_iBuffer, m_ibuffers[i].pvBuffer, m_ibuffers[i].cbBuffer );
00843 //          m_bufferOffset = m_ibuffers[i].cbBuffer;
00844         }
00845       }
00846 
00847       if( dataBuffer )
00848       {
00849         if( dataBuffer->cbBuffer > len )
00850         {
00851           memcpy( data, dataBuffer->pvBuffer, len );
00852           return len;
00853         }
00854         else
00855         {
00856           memcpy( data, dataBuffer->pvBuffer, dataBuffer->cbBuffer );
00857           readable += dataBuffer->cbBuffer;
00858           printf( "recvbuffer (%d): %s\n", readable, data );
00859         }
00860       }
00861 
00862       if( ret == SEC_I_RENEGOTIATE )
00863       {
00864         printf( "server requested reneg\n" );
00865         ret = handshakeLoop();
00866       }
00867     }
00868 
00869     return readable;
00870   }
00871 
00872   inline bool Connection::tls_dataAvailable()
00873   {
00874     return false;
00875   }
00876 
00877   inline void Connection::tls_cleanup()
00878   {
00879     m_securityFunc->DeleteSecurityContext( &m_context );
00880   }
00881 #endif
00882 
00883 #ifdef HAVE_ZLIB
00884   bool Connection::initCompression( StreamFeature method )
00885   {
00886     delete m_compression;
00887     m_compression = 0;
00888     m_compression = new Compression( method );
00889     return true;
00890   }
00891 
00892   void Connection::enableCompression()
00893   {
00894     if( !m_compression )
00895       return;
00896 
00897     m_enableCompression = true;
00898   }
00899 #endif
00900 
00901   ConnectionState Connection::connect()
00902   {
00903     if( m_socket != -1 && m_state >= StateConnecting )
00904     {
00905       return m_state;
00906     }
00907 
00908     m_state = StateConnecting;
00909 
00910     if( m_port == ( unsigned short ) -1 )
00911       m_socket = DNS::connect( m_server, m_logInstance );
00912     else
00913       m_socket = DNS::connect( m_server, m_port, m_logInstance );
00914 
00915     if( m_socket < 0 )
00916     {
00917       switch( m_socket )
00918       {
00919         case -DNS::DNS_COULD_NOT_CONNECT:
00920           m_logInstance.log( LogLevelError, LogAreaClassConnection, "connection error: could not connect" );
00921           break;
00922         case -DNS::DNS_NO_HOSTS_FOUND:
00923           m_logInstance.log( LogLevelError, LogAreaClassConnection, "connection error: no hosts found" );
00924           break;
00925         case -DNS::DNS_COULD_NOT_RESOLVE:
00926           m_logInstance.log( LogLevelError, LogAreaClassConnection, "connection error: could not resolve" );
00927           break;
00928       }
00929       cleanup();
00930     }
00931     else
00932       m_state = StateConnected;
00933 
00934     m_cancel = false;
00935     return m_state;
00936   }
00937 
00938   void Connection::disconnect( ConnectionError e )
00939   {
00940     m_disconnect = e;
00941     m_cancel = true;
00942 
00943     if( m_fdRequested )
00944       cleanup();
00945   }
00946 
00947   int Connection::fileDescriptor()
00948   {
00949     m_fdRequested = true;
00950     return m_socket;
00951   }
00952 
00953   bool Connection::dataAvailable( int timeout )
00954   {
00955 #ifdef HAVE_TLS
00956     if( tls_dataAvailable() )
00957     {
00958         return true;
00959     }
00960 #endif
00961 
00962     fd_set fds;
00963     struct timeval tv;
00964 
00965     FD_ZERO( &fds );
00966     FD_SET( m_socket, &fds );
00967 
00968     tv.tv_sec = timeout / 1000;
00969     tv.tv_usec = timeout % 1000;
00970 
00971     if( select( m_socket + 1, &fds, 0, 0, timeout == -1 ? 0 : &tv ) >= 0 )
00972     {
00973       return FD_ISSET( m_socket, &fds ) ? true : false;
00974     }
00975     return false;
00976   }
00977 
00978   ConnectionError Connection::recv( int timeout )
00979   {
00980     if( m_cancel )
00981     {
00982       ConnectionError e = m_disconnect;
00983       cleanup();
00984       return e;
00985     }
00986 
00987     if( m_socket == -1 )
00988       return ConnNotConnected;
00989 
00990     if( !m_fdRequested && !dataAvailable( timeout ) )
00991     {
00992         return ConnNoError;
00993     }
00994 
00995     // optimize(?): recv returns the size. set size+1 = \0
00996     memset( m_buf, '\0', m_bufsize + 1 );
00997     int size = 0;
00998 #ifdef HAVE_TLS
00999     if( m_secure )
01000     {
01001       size = tls_recv( m_buf, m_bufsize );
01002     }
01003     else
01004 #endif
01005     {
01006 #ifdef SKYOS
01007       size = ::recv( m_socket, (unsigned char*)m_buf, m_bufsize, 0 );
01008 #else
01009       size = ::recv( m_socket, m_buf, m_bufsize, 0 );
01010 #endif
01011     }
01012 
01013     if( size < 0 )
01014     {
01015       // error
01016       return ConnIoError;
01017     }
01018     else if( size == 0 )
01019     {
01020       // connection closed
01021       return ConnUserDisconnected;
01022     }
01023 
01024     std::string buf;
01025     buf.assign( m_buf, size );
01026     if( m_compression && m_enableCompression )
01027       buf = m_compression->decompress( buf );
01028 
01029     Parser::ParserState ret = m_parser->feed( buf );
01030     if( ret != Parser::PARSER_OK )
01031     {
01032       cleanup();
01033       switch( ret )
01034       {
01035         case Parser::PARSER_BADXML:
01036           m_logInstance.log( LogLevelError, LogAreaClassConnection, "XML parse error" );
01037           break;
01038         case Parser::PARSER_NOMEM:
01039           m_logInstance.log( LogLevelError, LogAreaClassConnection, "memory allocation error" );
01040           break;
01041         default:
01042           m_logInstance.log( LogLevelError, LogAreaClassConnection, "unexpected error" );
01043           break;
01044       }
01045       //printf( "buffer data: %s\n", buf.c_str() );
01046       return ConnIoError;
01047     }
01048 
01049     return ConnNoError;
01050   }
01051 
01052   ConnectionError Connection::receive()
01053   {
01054     if( m_socket == -1 || !m_parser )
01055       return ConnNotConnected;
01056 
01057     while( !m_cancel )
01058     {
01059       ConnectionError r = recv( 1 );
01060       if( r != ConnNoError )
01061         return r;
01062     }
01063     cleanup();
01064 
01065     return m_disconnect;
01066   }
01067 
01068   bool Connection::send( const std::string& data )
01069   {
01070     if( data.empty() || ( m_socket == -1 ) )
01071       return false;
01072 
01073     std::string xml;
01074     if( m_compression && m_enableCompression )
01075       xml = m_compression->compress( data );
01076     else
01077       xml = data;
01078 
01079 #ifdef HAVE_TLS
01080     if( m_secure )
01081     {
01082       size_t len = xml.length();
01083       if( tls_send( xml.c_str (), len ) == false )
01084         return false;
01085     }
01086     else
01087 #endif
01088     {
01089       size_t num = 0;
01090       size_t len = xml.length();
01091       while( num < len )
01092       {
01093 #ifdef SKYOS
01094         int sent = ::send( m_socket, (unsigned char*)(xml.c_str()+num), len - num, 0 );
01095 #else
01096         int sent = ::send( m_socket, (xml.c_str()+num), len - num, 0 );
01097 #endif
01098         if( sent == -1 )
01099           return false;
01100 
01101         num += sent;
01102       }
01103     }
01104 
01105     return true;
01106   }
01107 
01108   void Connection::cleanup()
01109   {
01110 #ifdef HAVE_TLS
01111     if( m_secure )
01112     {
01113       tls_cleanup();
01114     }
01115 #endif
01116 
01117     if( m_socket != -1 )
01118     {
01119 #ifdef WIN32
01120       closesocket( m_socket );
01121 #else
01122       close( m_socket );
01123 #endif
01124       m_socket = -1;
01125     }
01126     m_state = StateDisconnected;
01127     m_disconnect = ConnNoError;
01128     m_enableCompression = false;
01129     m_secure = false;
01130     m_cancel = true;
01131     m_fdRequested = false;
01132   }
01133 
01134 }

Generated on Sun Sep 24 21:57:31 2006 for gloox by  doxygen 1.4.7