diff --git a/include/common/LHashTable.h b/include/common/LHashTable.h new file mode 100644 --- /dev/null +++ b/include/common/LHashTable.h @@ -0,0 +1,784 @@ +/* + More modern take on the GHashTbl I had been using for a while. + Moved the key management into a parameter class. All the key pooling + is also now managed by the param class rather than the hash table itself. +*/ +#ifndef _LHashTbl_H_ +#define _LHashTbl_H_ + +#include +#include "GMem.h" +#include "GArray.h" +#include "GString.h" +#include "LgiClass.h" + +#ifndef LHASHTBL_MAX_SIZE +#define LHASHTBL_MAX_SIZE (64 << 10) +#endif + +#define HASH_TABLE_SHRINK_THRESHOLD 15 +#define HASH_TABLE_GROW_THRESHOLD 50 + +template +class IntKey +{ +public: + typedef T Type; + + T NullKey; + + IntKey() + { + NullKey = DefaultNull; + } + + void EmptyKeys() {} + uint32 Hash(T k) { return (uint32)k; } + T CopyKey(T a) { return a; } + size_t SizeKey(T a) { return sizeof(a); } + void FreeKey(T &a) { a = NullKey; } + bool CmpKey(T a, T b) + { + return a == b; + } + size_t TotalSize() { return 0; } +}; + +template +class PtrKey +{ +public: + typedef T Type; + + T NullKey; + + PtrKey() + { + NullKey = DefaultNull; + } + + void EmptyKeys() {} + uint32 Hash(T k) { return (uint32)(((size_t)k)/31); } + T CopyKey(T a) { return a; } + size_t SizeKey(T a) { return sizeof(a); } + void FreeKey(T &a) { a = NullKey; } + bool CmpKey(T a, T b) + { + return a == b; + } + size_t TotalSize() { return 0; } +}; + +template +class StrKey +{ +public: + typedef T *Type; + + T *NullKey; + + StrKey() + { + NullKey = DefaultNull; + } + + void EmptyKeys() {} + uint32 Hash(T *k) { return LHash(k, Strlen(k), CaseSen); } + T *CopyKey(T *a) { return Strdup(a); } + size_t SizeKey(T *a) { return (Strlen(a)+1)*sizeof(*a); } + void FreeKey(T *&a) { if (a) delete [] a; a = NullKey; } + bool CmpKey(T *a, T *b) { return !(CaseSen ? Strcmp(a, b) : Stricmp(a, b)); } + size_t TotalSize() { return 0; } +}; + +template +class KeyPool +{ +protected: + struct Buf : public GArray + { + size_t Used; + Buf(size_t Sz = 0) { this->Length(Sz); } + size_t Free() { return this->Length() - Used; } + }; + + GArray Mem; + Buf *GetMem(size_t Sz) + { + if (!Mem.Length() || Mem.Last().Free() < Sz) + Mem.New().Length(PoolSize); + return Mem.Last().Free() >= Sz ? &Mem.Last() : NULL; + } + +public: + const int DefaultPoolSize = (64 << 10) / sizeof(T); + int PoolSize; + + KeyPool() + { + PoolSize = BlockSize ? BlockSize : DefaultPoolSize; + } + + void EmptyKeys() + { + Mem.Length(0); + } + + size_t TotalSize() + { + size_t s = 0; + for (auto &b : Mem) + s += sizeof(Buf) + (b.Length() * sizeof(T)); + return s; + } +}; + +template +class ConstStrKey +{ +public: + typedef const T *Type; + + const T *NullKey; + + ConstStrKey() + { + NullKey = DefaultNull; + } + + void EmptyKeys() {} + uint32 Hash(const T *k) { return LHash(k, Strlen(k), CaseSen); } + T *CopyKey(const T *a) { return Strdup(a); } + size_t SizeKey(const T *a) { return (Strlen(a)+1)*sizeof(*a); } + void FreeKey(const T *&a) { if (a) delete [] a; a = NullKey; } + bool CmpKey(const T *a, const T *b) { return !(CaseSen ? Strcmp(a, b) : Stricmp(a, b)); } + size_t TotalSize() { return 0; } +}; + +template +class StrKeyPool : public KeyPool +{ +public: + typedef T *Type; + using Buf = typename KeyPool::Buf; + + T *NullKey; + + StrKeyPool() + { + NullKey = DefaultNull; + } + + uint32 Hash(T *k) { return LHash(k, Strlen(k), CaseSen); } + size_t SizeKey(T *a) { return (Strlen(a)+1)*sizeof(*a); } + bool CmpKey(T *a, T *b) { return !(CaseSen ? Strcmp(a, b) : Stricmp(a, b)); } + + T *CopyKey(T *a) + { + size_t Sz = Strlen(a) + 1; + Buf *m = this->GetMem(Sz); + if (!m) return NullKey; + T *r = m->AddressOf(m->Used); + memcpy(r, a, Sz*sizeof(*a)); + m->Used += Sz; + return r; + } + + void FreeKey(T *&a) + { + // Do nothing... memory is own by KeyPool + a = NullKey; + } +}; + +template +class ConstStrKeyPool : public KeyPool +{ +public: + typedef const T *Type; + using Buf = typename KeyPool::Buf; + + const T *NullKey; + + ConstStrKeyPool() + { + NullKey = DefaultNull; + } + + uint32 Hash(const T *k) { return LHash(k, Strlen(k), CaseSen); } + size_t SizeKey(const T *a) { return (Strlen(a)+1)*sizeof(*a); } + bool CmpKey(const T *a, const T *b) { return !(CaseSen ? Strcmp(a, b) : Stricmp(a, b)); } + + const T *CopyKey(const T *a) + { + size_t Sz = Strlen(a) + 1; + Buf *m = this->GetMem(Sz); + if (!m) return NullKey; + T *r = m->AddressOf(m->Used); + memcpy(r, a, Sz*sizeof(*a)); + m->Used += Sz; + return r; + } + + void FreeKey(const T *&a) + { + // Do nothing... memory is own by KeyPool + a = NullKey; + } +}; + +/// General hash table container for O(1) access to table data. +template +class LHashTbl : public KeyTrait +{ +public: + typedef typename KeyTrait::Type Key; + typedef LHashTbl HashTable; + const int DefaultSize = 256; + + struct Pair + { + Key key; + Value value; + }; + +protected: + Value NullValue; + + size_t Used; + size_t Size; + size_t MaxSize; + int Version; // This changes every time 'Table' is resized. + // It's used to invalidate iterators. + Pair *Table; + + int Percent() + { + return (int) (Used * 100 / Size); + } + + bool GetEntry(const Key k, ssize_t &Index, bool Debug = false) + { + if (k != this->NullKey && Table) + { + uint32 h = this->Hash(k); + + for (size_t i=0; iNullKey) + return false; + + if (this->CmpKey(Table[Index].key, k)) + return true; + } + } + + return false; + } + + bool Between(ssize_t Val, ssize_t Min, ssize_t Max) + { + if (Min <= Max) + { + // Not wrapped + return Val >= Min && Val <= Max; + } + else + { + // Wrapped + return Val <= Max || Val >= Min; + } + } + + + void InitializeTable(Pair *e, ssize_t len) + { + if (!e || len < 1) return; + while (len--) + { + e->key = this->NullKey; + e->value = NullValue; + e++; + } + } + +public: + /// Constructs the hash table + LHashTbl + ( + /// Sets the initial table size. Should be 2x your data set. + size_t size = 0, + /// The default empty value + Value nullvalue = (Value)0 + ) + { + Size = size; + NullValue = nullvalue; + Used = 0; + Version = 0; + MaxSize = LHASHTBL_MAX_SIZE; + // LgiAssert(Size <= MaxSize); + + if ((Table = new Pair[Size])) + { + InitializeTable(Table, Size); + } + } + + LHashTbl(const HashTable &init) + { + Size = init.Size; + NullValue = init.NullValue; + Used = 0; + Version = 0; + MaxSize = LHASHTBL_MAX_SIZE; + if ((Table = new Pair[Size])) + { + for (size_t i=0; iNullKey; + } + *this = init; + } + } + + /// Deletes the hash table removing all contents from memory + virtual ~LHashTbl() + { + Empty(); + DeleteArray(Table); + } + + Key GetNullKey() + { + return this->NullKey; + } + + /// Copy operator + HashTable &operator =(const HashTable &c) + { + if (IsOk() && c.IsOk()) + { + Empty(); + + this->NullKey = c.NullKey; + NullValue = c.NullValue; + + size_t Added = 0; + for (size_t i=0; iNullKey) + { + if (!Add(OldTable[i].key, OldTable[i].value)) + { + LgiAssert(0); + } + this->FreeKey(OldTable[i].key); + } + } + + Version++; + Status = true; + } + else + { + LgiAssert(Table != 0); + Table = OldTable; + Size = OldSize; + return false; + } + + DeleteArray(OldTable); + } + + return Status; + } + + /// Returns true if the object appears to be valid + bool IsOk() const + { + bool Status = + #ifndef __llvm__ + this != 0 && + #endif + Table != 0; + if (!Status) + { + #ifndef LGI_STATIC + LgiTrace("%s:%i - this=%p Table=%p Used=%i Size=%i\n", _FL, this, Table, Used, Size); + #endif + LgiAssert(0); + } + return Status; + } + + /// Gets the number of entries used + size_t Length() + { + return IsOk() ? Used : 0; + } + + /// Adds a value under a given key + bool Add + ( + /// The key to insert the value under + Key k, + /// The value to insert + Value v + ) + { + if (!Size) + SetSize(DefaultSize); + + if (IsOk() && + k == this->NullKey && + v == NullValue) + { + LgiAssert(!"Adding NULL key or value."); + return false; + } + + uint32 h = this->Hash(k); + + ssize_t Index = -1; + for (size_t i=0; iNullKey + || + this->CmpKey(Table[idx].key, k) + ) + { + Index = idx; + break; + } + } + + if (Index >= 0) + { + if (Table[Index].key == this->NullKey) + { + Table[Index].key = this->CopyKey(k); + Used++; + } + Table[Index].value = v; + + if (Percent() > HASH_TABLE_GROW_THRESHOLD) + { + SetSize(Size << 1); + } + return true; + } + + LgiAssert(!"Couldn't alloc space."); + return false; + } + + /// Deletes a value at 'key' + bool Delete + ( + /// The key of the value to delete + Key k, + /// Turns off resizing, in case your iterating over the hash table, + /// where resizing would invalidate the iterators. + bool NoResize = false + ) + { + ssize_t Index = -1; + if (GetEntry(k, Index)) + { + // Delete the entry + this->FreeKey(Table[Index].key); + Table[Index].value = NullValue; + Used--; + + // Bubble down any entries above the hole + ssize_t Hole = Index; + for (ssize_t i = (Index + 1) % Size; i != Index; i = (i + 1) % Size) + { + if (Table[i].key != this->NullKey) + { + uint32 Hsh = this->Hash(Table[i].key); + uint32 HashIndex = Hsh % Size; + + if (HashIndex != i && Between(Hole, HashIndex, i)) + { + // Do bubble + if (Table[Hole].key != this->NullKey) + { + LgiAssert(0); + } + memmove(Table + Hole, Table + i, sizeof(Table[i])); + InitializeTable(Table + i, 1); + Hole = i; + } + } + else + { + // Reached the end of entries that could have bubbled + break; + } + } + + // Check for auto-shrink limit + if (!NoResize && + Percent() < HASH_TABLE_SHRINK_THRESHOLD) + { + SetSize(Size >> 1); + } + + return true; + } + else + { + GetEntry(k, Index, true); + } + + return false; + } + + /// Returns the value at 'key' + Value Find(const Key k) + { + ssize_t Index = -1; + if (IsOk() && GetEntry(k, Index)) + { + return Table[Index].value; + } + + return NullValue; + } + + /// Returns the Key at 'val' + Key FindKey(const Value val) + { + if (IsOk()) + { + Pair *c = Table; + Pair *e = Table + Size; + while (c < e) + { + if (CmpKey(c->value, val)) + { + return c->key; + } + c++; + } + } + + return this->NullKey; + } + + /// Removes all key/value pairs from memory + void Empty() + { + if (!IsOk()) + return; + + for (size_t i=0; iNullKey) + { + this->FreeKey(Table[i].key); + LgiAssert(Table[i].key == this->NullKey); + } + Table[i].value = NullValue; + } + + Used = 0; + this->EmptyKeys(); + } + + /// Returns the amount of memory in use by the hash table. + int64 Sizeof() + { + int64 Sz = sizeof(*this); + + Sz += Sz * sizeof(Pair); + + int64 KeySize = 0; + size_t Total = KeyTrait::TotalSize(); + if (Total) + { + KeySize += Total; + } + else + { + int Keys = 0; + for (size_t i=0; iNullKey) + { + Keys++; + KeySize += this->SizeKey(Table[i].key); + } + } + } + + return Sz + KeySize; + } + + /// Deletes values as objects + void DeleteObjects() + { + for (size_t i=0; iNullKey) + this->FreeKey(Table[i].key); + + if (Table[i].value != NullValue) + DeleteObj(Table[i].value); + } + + Used = 0; + } + + /// Deletes values as arrays + void DeleteArrays() + { + for (size_t i=0; iNullKey) + this->FreeKey(Table[i].key); + + if (Table[i].value != NullValue) + DeleteArray(Table[i].value); + } + + Used = 0; + } + + /// Swaps the objects + void Swap(LHashTbl &h) + { + LSwap(this->NullKey, h.NullKey); + LSwap(NullValue, h.NullValue); + LSwap(Used, h.Used); + LSwap(Size, h.Size); + LSwap(MaxSize, h.MaxSize); + LSwap(Version, h.Version); + LSwap(Table, h.Table); + } + + struct PairIterator + { + LHashTbl *t; + ssize_t Idx; + int Version; + + public: + PairIterator(LHashTbl *tbl, ssize_t i) + { + t = tbl; + Version = t->Version; + Idx = i; + if (Idx < 0) + Next(); + } + + bool operator !=(const PairIterator &it) const + { + bool Eq = t == it.t && + Idx == it.Idx; + return !Eq; + } + + PairIterator &Next() + { + if (t->IsOk()) + { + if (Version != t->Version) + { + #ifndef LGI_UNIT_TESTS + LgiAssert(!"Iterator invalidated"); + #endif + *this = t->end(); + } + else + { + while (++Idx < (ssize_t)t->Size) + { + if (t->Table[Idx].key != t->NullKey) + break; + } + } + } + + return *this; + } + + PairIterator &operator ++() { return Next(); } + PairIterator &operator ++(int) { return Next(); } + + Pair &operator *() + { + LgiAssert( Idx >= 0 && + Idx < (ssize_t)t->Size && + t->Table[Idx].key != t->NullKey); + return t->Table[Idx]; + } + }; + + PairIterator begin() + { + return PairIterator(this, -1); + } + + PairIterator end() + { + return PairIterator(this, Size); + } +}; + +#endif + diff --git a/include/common/LgiClass.h b/include/common/LgiClass.h --- a/include/common/LgiClass.h +++ b/include/common/LgiClass.h @@ -1,299 +1,344 @@ #ifndef _LGI_CLASS_H_ #define _LGI_CLASS_H_ #include "LgiInc.h" #include "LgiDefs.h" // Virtual input classes class GKey; class GMouse; // General GUI classes class GTarget; class GComponent; class GEvent; class GId; class GApp; class GWindow; class GWin32Class; class GView; class GLayout; class GFileSelect; class GFindReplace; class GSubMenu; class GMenuItem; class GMenu; class GToolBar; class GToolButton; class GSplitter; class GStatusPane; class GStatusBar; class GToolColour; class GScrollBar; class GImageList; class GDialog; // General objects class LgiClass GBase { char *_Name8; char16 *_Name16; public: GBase(); virtual ~GBase(); virtual char *Name(); virtual bool Name(const char *n); virtual char16 *NameW(); virtual bool NameW(const char16 *n); }; #define AssignFlag(f, bit, to) if (to) f |= bit; else f &= ~(bit) /// Sets the output stream for the LgiTrace statement. By default the stream output /// is to .txt in the executables folder or $LSP_APP_ROOT\.txt if /// that is not writable. If the stream is set to something then normal file output is /// directed to the specified stream instead. LgiFunc void LgiTraceSetStream(class GStreamI *stream); /// Gets the log file path LgiFunc bool LgiTraceGetFilePath(char *LogPath, int BufLen); /// Writes a debug statement to a output stream, or if not defined with LgiTraceSetStream /// then to a log file (see LgiTraceSetStream for details) /// /// Default path is ./.txt relative to the executable. /// Fallback path is LgiGetSystemPath(LSP_APP_ROOT). LgiFunc void LgiTrace(const char *Format, ...); #ifndef LGI_STATIC /// Same as LgiTrace but writes a stack trace as well. LgiFunc void LgiStackTrace(const char *Format, ...); #endif /// General user interface event class LgiClass GUiEvent { public: int Flags; GUiEvent() { Flags = 0; } virtual ~GUiEvent() {} virtual void Trace(const char *Msg) {} /// The key or mouse button was being pressed. false on the up-click. bool Down() { return TestFlag(Flags, LGI_EF_DOWN); } /// The mouse button was double clicked. bool Double() { return TestFlag(Flags, LGI_EF_DOUBLE); } /// A ctrl button was held down during the event bool Ctrl() { return TestFlag(Flags, LGI_EF_CTRL); } /// A alt button was held down during the event bool Alt() { return TestFlag(Flags, LGI_EF_ALT); } /// A shift button was held down during the event bool Shift() { return TestFlag(Flags, LGI_EF_SHIFT); } /// The system key was held down (windows key / apple key etc) bool System() { return TestFlag(Flags, LGI_EF_SYSTEM); } // Set void Down(bool i) { AssignFlag(Flags, LGI_EF_DOWN, i); } void Double(bool i) { AssignFlag(Flags, LGI_EF_DOUBLE, i); } void Ctrl(bool i) { AssignFlag(Flags, LGI_EF_CTRL, i); } void Alt(bool i) { AssignFlag(Flags, LGI_EF_ALT, i); } void Shift(bool i) { AssignFlag(Flags, LGI_EF_SHIFT, i); } void System(bool i) { AssignFlag(Flags, LGI_EF_SYSTEM, i); } bool Modifier() { #if defined(BEOS) return Alt(); #elif defined(MAC) return System(); // "Apple" key #else // win32 and linux return Ctrl(); #endif } void SetModifer(uint32 modifierKeys) { #if defined(MAC) #if defined COCOA #warning FIXME #else System(modifierKeys & cmdKey); Shift(modifierKeys & shiftKey); Alt(modifierKeys & optionKey); Ctrl(modifierKeys & controlKey); #endif #elif defined(__GTK_H__) System(modifierKeys & Gtk::GDK_MOD4_MASK); Shift(modifierKeys & Gtk::GDK_SHIFT_MASK); Alt(modifierKeys & Gtk::GDK_MOD1_MASK); Ctrl(modifierKeys & Gtk::GDK_CONTROL_MASK); #endif } }; /// All the information related to a keyboard event class LgiClass GKey : public GUiEvent { public: /// The virtual code for key char16 vkey; /// The unicode character for the key char16 c16; /// OS Specific uint32 Data; /// True if this is a standard character (ie not a control key) bool IsChar; GKey() { vkey = 0; c16 = 0; Data = 0; IsChar = 0; } GKey(int vkey, int flags); void Trace(const char *Msg) { LgiTrace("%s GKey vkey=%i(0x%x) c16=%i(%c) IsChar=%i down=%i ctrl=%i alt=%i sh=%i sys=%i\n", Msg ? Msg : (char*)"", vkey, vkey, c16, c16 >= ' ' && c16 < 127 ? c16 : '.', IsChar, Down(), Ctrl(), Alt(), Shift(), System()); } /// Returns the character in the right case... char16 GetChar() { if (Shift() ^ TestFlag(Flags, LGI_EF_CAPS_LOCK)) { return (c16 >= 'a' && c16 <= 'z') ? c16 - 'a' + 'A' : c16; } else { return (c16 >= 'A' && c16 <= 'Z') ? c16 - 'A' + 'a' : c16; } } /// \returns true if this event should show a context menu bool IsContextMenu(); }; /// \brief All the parameters of a mouse click event /// /// The parent class GUiEvent keeps information about whether it was a Down() /// or Double() click. You can also query whether the Alt(), Ctrl() or Shift() /// keys were pressed at the time the event occured. /// /// To get the position of the mouse in screen co-ordinates you can either use /// GView::GetMouse() and pass true in the 'ScreenCoords' parameter. Or you can /// construct a GdcPt2 out of the x,y fields of this class and use GView::PointToScreen() /// to map the point to screen co-ordinates. class LgiClass GMouse : public GUiEvent { public: /// Receiving view class GViewI *Target; /// True if specified in view coordinates, false if in screen coords bool ViewCoords; /// The x co-ordinate of the mouse relitive to the current view int x; /// The y co-ordinate of the mouse relitive to the current view int y; GMouse() { Target = 0; ViewCoords = true; x = y = 0; } void Trace(const char *Msg) { LgiTrace("%s GMouse pos=%i,%i view=%i btns=%i/%i/%i dwn=%i dbl=%i " "ctrl=%i alt=%i sh=%i sys=%i\n", Msg ? Msg : (char*)"", x, y, ViewCoords, Left(), Middle(), Right(), Down(), Double(), Ctrl(), Alt(), Shift(), System()); } /// True if the left mouse button was clicked bool Left() { return TestFlag(Flags, LGI_EF_LEFT); } /// True if the middle mouse button was clicked bool Middle() { return TestFlag(Flags, LGI_EF_MIDDLE); } /// True if the right mouse button was clicked bool Right() { return TestFlag(Flags, LGI_EF_RIGHT); } /// True if the mouse event is a move, false for a click event. bool IsMove() { return TestFlag(Flags, LGI_EF_MOVE); } /// Sets the left button flag void Left(bool i) { AssignFlag(Flags, LGI_EF_LEFT, i); } /// Sets the middle button flag void Middle(bool i) { AssignFlag(Flags, LGI_EF_MIDDLE, i); } /// Sets the right button flag void Right(bool i) { AssignFlag(Flags, LGI_EF_RIGHT, i); } /// Sets the move flag void IsMove(bool i) { AssignFlag(Flags, LGI_EF_MOVE, i); } /// Converts to screen coordinates bool ToScreen(); /// Converts to local coordinates bool ToView(); /// \returns true if this event should show a context menu bool IsContextMenu(); void SetButton(uint32 Btn) { #ifdef MAC #if defined COCOA #warning FIXME #else Left(Btn == kEventMouseButtonPrimary); Right(Btn == kEventMouseButtonSecondary); Middle(Btn == kEventMouseButtonTertiary); #endif #endif } }; #include "GAutoPtr.h" /// Holds information pertaining to an application class GAppInfo { public: /// The path to the executable for the app GAutoString Path; /// Plain text name for the app GAutoString Name; /// A path to an icon to display for the app GAutoString Icon; /// The params to call the app with GAutoString Params; }; +template +RESULT LHash(const CHAR *v, ssize_t l, bool Case) +{ + RESULT h = 0; + + if (Case) + { + // case sensitive + if (l > 0) + { + while (l--) + h = (h << 5) - h + *v++; + } + else + { + for (; *v; v ++) + h = (h << 5) - h + *v; + } + } + else + { + // case insensitive + CHAR c; + if (l > 0) + { + while (l--) + { + c = tolower(*v); + v++; + h = (h << 5) - h + c; + } + } + else + { + for (; *v; v++) + { + c = tolower(*v); + h = (h << 5) - h + c; + } + } + } + + return h; +} + #endif diff --git a/include/common/OpenSSLSocket.h b/include/common/OpenSSLSocket.h --- a/include/common/OpenSSLSocket.h +++ b/include/common/OpenSSLSocket.h @@ -1,74 +1,76 @@ #ifndef _OPEN_SSL_SOCKET_H_ #define _OPEN_SSL_SOCKET_H_ #include "GLibraryUtils.h" #include "LCancel.h" // If you get a compile error on Linux: // sudo apt-get install libssl-dev #include "openssl/bio.h" #include "openssl/ssl.h" #include "openssl/err.h" #define SslSocket_LogFile "LogFile" #define SslSocket_LogFormat "LogFmt" class SslSocket : public GSocketI, virtual public GDom { friend class OpenSSL; struct SslSocketPriv *d; LMutex Lock; BIO *Bio; SSL *Ssl; GString ErrMsg; // Local stuff virtual void Log(const char *Str, ssize_t Bytes, SocketMsgType Type); void SslError(const char *file, int line, const char *Msg); GStream *GetLogStream(); void DebugTrace(const char *fmt, ...); public: static bool DebugLogging; SslSocket(GStreamI *logger = NULL, GCapabilityClient *caps = NULL, bool SslOnConnect = false, bool RawLFCheck = false); ~SslSocket(); void SetLogger(GStreamI *logger); void SetSslOnConnect(bool b); + LCancel *GetCancel(); + void SetCancel(LCancel *c); // Socket OsSocket Handle(OsSocket Set = INVALID_SOCKET); bool IsOpen(); int Open(const char *HostAddr, int Port); int Close(); bool Listen(int Port = 0); void OnError(int ErrorCode, const char *ErrorDescription); void OnInformation(const char *Str); int GetTimeout(); void SetTimeout(int ms); ssize_t Write(const void *Data, ssize_t Len, int Flags = 0); ssize_t Read(void *Data, ssize_t Len, int Flags = 0); void OnWrite(const char *Data, ssize_t Len); void OnRead(char *Data, ssize_t Len); bool IsReadable(int TimeoutMs = 0); bool IsWritable(int TimeoutMs = 0); bool IsBlocking(); void IsBlocking(bool block); bool SetVariant(const char *Name, GVariant &Val, char *Arr = NULL); bool GetVariant(const char *Name, GVariant &Val, char *Arr = NULL); GStreamI *Clone(); const char *GetErrorString(); }; extern bool StartSSL(GAutoString &ErrorMsg, SslSocket *Sock); extern void EndSSL(); -#endif \ No newline at end of file +#endif diff --git a/src/common/INet/OpenSSLSocket.cpp b/src/common/INet/OpenSSLSocket.cpp --- a/src/common/INet/OpenSSLSocket.cpp +++ b/src/common/INet/OpenSSLSocket.cpp @@ -1,1425 +1,1450 @@ /*hdr ** FILE: OpenSSLSocket.cpp ** AUTHOR: Matthew Allen ** DATE: 24/9/2004 ** DESCRIPTION: Open SSL wrapper socket ** ** Copyright (C) 2004-2014, Matthew Allen ** fret@memecode.com ** */ #include #ifdef WINDOWS #pragma comment(lib,"Ws2_32.lib") #else #include #endif #include "Lgi.h" #include "OpenSSLSocket.h" #ifdef WIN32 #include #endif #include "GToken.h" #include "GVariant.h" #include "INet.h" -#if OPENSSL_VERSION_NUMBER >= 0x10100000L -#error "SSL library too new." -#endif - #define PATH_OFFSET "../" #ifdef WIN32 -#define SSL_LIBRARY "ssleay32" -#define EAY_LIBRARY "libeay32" + #if OPENSSL_VERSION_NUMBER >= 0x10100000L + #ifdef _WIN64 + #define SSL_LIBRARY "libssl-1_1-x64" + #define EAY_LIBRARY "libcrypto-1_1-x64" + #else // 32bit + #define SSL_LIBRARY "libssl-1_1" + #define EAY_LIBRARY "libcrypto-1_1" + #endif + #else + #define SSL_LIBRARY "ssleay32" + #define EAY_LIBRARY "libeay32" + #endif #else #define SSL_LIBRARY "libssl" #endif #define HasntTimedOut() ((To < 0) || (LgiCurrentTime() - Start < To)) static const char* MinimumVersion = "1.0.1g"; void SSL_locking_function(int mode, int n, const char *file, int line); unsigned long SSL_id_function(); class LibSSL : public GLibrary { public: LibSSL() { char p[MAX_PATH]; #if defined MAC if (LgiGetExeFile(p, sizeof(p))) { LgiMakePath(p, sizeof(p), p, "Contents/MacOS/libssl.1.0.0.dylib"); if (FileExists(p)) { Load(p); } } if (!IsLoaded()) { Load("/opt/local/lib/" SSL_LIBRARY); } #elif defined LINUX if (LgiGetExePath(p, sizeof(p))) { LgiMakePath(p, sizeof(p), p, "libssl.so"); if (FileExists(p)) { LgiTrace("%s:%i - loading SSL library '%s'\n", _FL, p); Load(p); } } #endif if (!IsLoaded()) Load(SSL_LIBRARY); if (!IsLoaded()) { LgiGetExePath(p, sizeof(p)); LgiMakePath(p, sizeof(p), p, PATH_OFFSET "../OpenSSL"); #ifdef WIN32 - char old[300]; + char old[MAX_PATH]; FileDev->GetCurrentFolder(old, sizeof(old)); FileDev->SetCurrentFolder(p); #endif Load(SSL_LIBRARY); #ifdef WIN32 FileDev->SetCurrentFolder(old); #endif } } #if OPENSSL_VERSION_NUMBER >= 0x10100000L DynFunc0(int, OPENSSL_library_init); DynFunc0(int, OPENSSL_load_error_strings); - DynFunc2(int, OPENSSL_init_crypto, uint64_t, opts, const OPENSSL_INIT_SETTINGS *, settings); DynFunc2(int, OPENSSL_init_ssl, uint64_t, opts, const OPENSSL_INIT_SETTINGS *, settings); + DynFunc0(const SSL_METHOD *, TLS_method); + DynFunc0(const SSL_METHOD *, TLS_server_method); + DynFunc0(const SSL_METHOD *, TLS_client_method); #else DynFunc0(int, SSL_library_init); DynFunc0(int, SSL_load_error_strings); + DynFunc0(SSL_METHOD*, SSLv23_client_method); + DynFunc0(SSL_METHOD*, SSLv23_server_method); #endif DynFunc1(int, SSL_open, SSL*, s); DynFunc1(int, SSL_connect, SSL*, s); DynFunc4(long, SSL_ctrl, SSL*, ssl, int, cmd, long, larg, void*, parg); DynFunc1(int, SSL_shutdown, SSL*, s); DynFunc1(int, SSL_free, SSL*, ssl); DynFunc1(int, SSL_get_fd, const SSL *, s); DynFunc2(int, SSL_set_fd, SSL*, s, int, fd); DynFunc1(SSL*, SSL_new, SSL_CTX*, ctx); DynFunc1(BIO*, BIO_new_ssl_connect, SSL_CTX*, ctx); DynFunc1(X509*, SSL_get_peer_certificate, SSL*, s); DynFunc1(int, SSL_set_connect_state, SSL*, s); DynFunc1(int, SSL_set_accept_state, SSL*, s); DynFunc2(int, SSL_get_error, SSL*, s, int, ret_code); DynFunc3(int, SSL_set_bio, SSL*, s, BIO*, rbio, BIO*, wbio); DynFunc3(int, SSL_write, SSL*, ssl, const void*, buf, int, num); DynFunc3(int, SSL_read, SSL*, ssl, const void*, buf, int, num); DynFunc1(int, SSL_pending, SSL*, ssl); DynFunc1(BIO *, SSL_get_rbio, const SSL *, s); DynFunc1(int, SSL_accept, SSL *, ssl); - DynFunc0(SSL_METHOD*, SSLv23_client_method); - DynFunc0(SSL_METHOD*, SSLv23_server_method); - - DynFunc1(SSL_CTX*, SSL_CTX_new, SSL_METHOD*, meth); + DynFunc1(SSL_CTX*, SSL_CTX_new, const SSL_METHOD*, meth); DynFunc3(int, SSL_CTX_load_verify_locations, SSL_CTX*, ctx, const char*, CAfile, const char*, CApath); DynFunc3(int, SSL_CTX_use_certificate_file, SSL_CTX*, ctx, const char*, file, int, type); DynFunc3(int, SSL_CTX_use_PrivateKey_file, SSL_CTX*, ctx, const char*, file, int, type); DynFunc1(int, SSL_CTX_check_private_key, const SSL_CTX*, ctx); DynFunc1(int, SSL_CTX_free, SSL_CTX*, ctx); #ifdef WIN32 // If this is freaking you out then good... openssl-win32 ships // in 2 DLL's and on Linux everything is 1 shared object. Thus // the code reflects that. }; class LibEAY : public GLibrary { public: LibEAY() : GLibrary(EAY_LIBRARY) { if (!IsLoaded()) { char p[300]; LgiGetExePath(p, sizeof(p)); LgiMakePath(p, sizeof(p), p, PATH_OFFSET "../OpenSSL"); #ifdef WIN32 - char old[300]; + char old[MAX_PATH]; FileDev->GetCurrentFolder(old, sizeof(old)); FileDev->SetCurrentFolder(p); #endif Load("libeay32"); #ifdef WIN32 FileDev->SetCurrentFolder(old); #endif } } #endif typedef void (*locking_callback)(int mode,int type, const char *file,int line); typedef unsigned long (*id_callback)(); - DynFunc1(const char *, SSLeay_version, int, type); - DynFunc1(BIO*, BIO_new, BIO_METHOD*, type); DynFunc0(BIO_METHOD*, BIO_s_socket); DynFunc0(BIO_METHOD*, BIO_s_mem); DynFunc1(BIO*, BIO_new_connect, char *, host_port); DynFunc4(long, BIO_ctrl, BIO*, bp, int, cmd, long, larg, void*, parg); DynFunc4(long, BIO_int_ctrl, BIO *, bp, int, cmd, long, larg, int, iarg); DynFunc3(int, BIO_read, BIO*, b, void*, data, int, len); DynFunc3(int, BIO_write, BIO*, b, const void*, data, int, len); DynFunc1(int, BIO_free, BIO*, a); DynFunc1(int, BIO_free_all, BIO*, a); DynFunc2(int, BIO_test_flags, const BIO *, b, int, flags); DynFunc0(int, ERR_load_BIO_strings); #if OPENSSL_VERSION_NUMBER < 0x10100000L DynFunc0(int, ERR_free_strings); DynFunc0(int, EVP_cleanup); DynFunc0(int, OPENSSL_add_all_algorithms_noconf); DynFunc1(int, CRYPTO_set_locking_callback, locking_callback, func); DynFunc1(int, CRYPTO_set_id_callback, id_callback, func); DynFunc0(int, CRYPTO_num_locks); + DynFunc1(const char *, SSLeay_version, int, type); + #else + DynFunc2(int, OPENSSL_init_crypto, uint64_t, opts, const OPENSSL_INIT_SETTINGS *, settings); + DynFunc1(const char *, OpenSSL_version, int, type); #endif DynFunc1(const char *, ERR_lib_error_string, unsigned long, e); DynFunc1(const char *, ERR_func_error_string, unsigned long, e); DynFunc1(const char *, ERR_reason_error_string, unsigned long, e); DynFunc1(int, ERR_print_errors, BIO *, bp); DynFunc3(char*, X509_NAME_oneline, X509_NAME*, a, char*, buf, int, size); DynFunc1(X509_NAME*, X509_get_subject_name, X509*, a); DynFunc2(char*, ERR_error_string, unsigned long, e, char*, buf); DynFunc0(unsigned long, ERR_get_error); }; typedef GArray SslVer; SslVer ParseSslVersion(const char *v) { GToken t(v, "."); SslVer out; for (unsigned i=0; i(SslVer &a, SslVer &b) { return CompareSslVersion(a, b) > 0; } static const char *FileLeaf(const char *f) { const char *l = strrchr(f, DIR_CHAR); return l ? l + 1 : f; } #undef _FL #define _FL FileLeaf(__FILE__), __LINE__ class OpenSSL : #ifdef WINDOWS public LibEAY, #endif public LibSSL { SSL_CTX *Server; public: SSL_CTX *Client; GArray Locks; GAutoString ErrorMsg; bool IsLoaded() { return LibSSL::IsLoaded() #ifdef WINDOWS && LibEAY::IsLoaded() #endif ; } bool InitLibrary(SslSocket *sock) { GStringPipe Err; GArray Ver; GArray MinimumVer = ParseSslVersion(MinimumVersion); GToken t; int Len = 0; const char *v = NULL; if (!IsLoaded()) { - Err.Print("%s:%i - SSL libraries missing.\n", _FL); + #ifdef EAY_LIBRARY + Err.Print("%s:%i - SSL libraries missing (%s, %s)\n", _FL, SSL_LIBRARY, EAY_LIBRARY); + #else + Err.Print("%s:%i - SSL library missing (%s)\n", _FL, SSL_LIBRARY); + #endif goto OnError; } SSL_library_init(); SSL_load_error_strings(); ERR_load_BIO_strings(); OpenSSL_add_all_algorithms(); Len = CRYPTO_num_locks(); Locks.Length(Len); CRYPTO_set_locking_callback(SSL_locking_function); CRYPTO_set_id_callback(SSL_id_function); v = SSLeay_version(SSLEAY_VERSION); if (!v) { Err.Print("%s:%i - SSLeay_version failed.\n", _FL); goto OnError; } t.Parse(v, " "); if (t.Length() < 2) { Err.Print("%s:%i - SSLeay_version: no version\n", _FL); goto OnError; } Ver = ParseSslVersion(t[1]); if (Ver.Length() < 3) { Err.Print("%s:%i - SSLeay_version: not enough tokens\n", _FL); goto OnError; } if (Ver < MinimumVer) { #if WINDOWS char FileName[MAX_PATH] = ""; DWORD r = GetModuleFileNameA(LibEAY::Handle(), FileName, sizeof(FileName)); #endif Err.Print("%s:%i - SSL version '%s' is too old (minimum '%s')\n" #if WINDOWS "%s\n" #endif , _FL, t[1], MinimumVersion #if WINDOWS ,FileName #endif ); goto OnError; } Client = SSL_CTX_new(SSLv23_client_method()); if (!Client) { long e = ERR_get_error(); char *Msg = ERR_error_string(e, 0); Err.Print("%s:%i - SSL_CTX_new(client) failed with '%s' (%i)\n", _FL, Msg, e); goto OnError; } return true; OnError: ErrorMsg.Reset(Err.NewStr()); if (sock) sock->DebugTrace("%s", ErrorMsg.Get()); return false; } OpenSSL() { Client = NULL; Server = NULL; } ~OpenSSL() { if (Client) { SSL_CTX_free(Client); Client = NULL; } if (Server) { SSL_CTX_free(Server); Server = NULL; } Locks.DeleteObjects(); } SSL_CTX *GetServer(SslSocket *sock, const char *CertFile, const char *KeyFile) { if (!Server) { Server = SSL_CTX_new(SSLv23_server_method()); if (Server) { if (CertFile) SSL_CTX_use_certificate_file(Server, CertFile, SSL_FILETYPE_PEM); if (KeyFile) SSL_CTX_use_PrivateKey_file(Server, KeyFile, SSL_FILETYPE_PEM); if (!SSL_CTX_check_private_key(Server)) { LgiAssert(0); } } else { long e = ERR_get_error(); char *Msg = ERR_error_string(e, 0); GStringPipe p; p.Print("%s:%i - SSL_CTX_new(server) failed with '%s' (%i)\n", _FL, Msg, e); ErrorMsg.Reset(p.NewStr()); sock->DebugTrace("%s", ErrorMsg.Get()); } } return Server; } bool IsOk(SslSocket *sock) { bool Loaded = #ifdef WIN32 LibSSL::IsLoaded() && LibEAY::IsLoaded(); #else IsLoaded(); #endif if (Loaded) return true; // Try and load again... cause the library can be provided by install on demand. #ifdef WIN32 Loaded = LibSSL::Load(SSL_LIBRARY) && LibEAY::Load(EAY_LIBRARY); #else Loaded = Load(SSL_LIBRARY); #endif if (Loaded) InitLibrary(sock); return Loaded; } }; static OpenSSL *Library = 0; #if 0 #define SSL_DEBUG_LOCKING #endif void SSL_locking_function(int mode, int n, const char *file, int line) { LgiAssert(Library != NULL); if (Library) { if (!Library->Locks[n]) { #ifdef SSL_DEBUG_LOCKING LgiTrace("SSL[%i] create\n", n); #endif Library->Locks[n] = new LMutex; } #ifdef SSL_DEBUG_LOCKING LgiTrace("SSL[%i] lock=%i, unlock=%i, re=%i, wr=%i (mode=0x%x, cnt=%i, thr=0x%x, %s:%i)\n", n, TestFlag(mode, CRYPTO_LOCK), TestFlag(mode, CRYPTO_UNLOCK), TestFlag(mode, CRYPTO_READ), TestFlag(mode, CRYPTO_WRITE), mode, Library->Locks[n]->GetCount(), LgiGetCurrentThread(), file, line); #endif if (mode & CRYPTO_LOCK) Library->Locks[n]->Lock((char*)file, line); else if (mode & CRYPTO_UNLOCK) Library->Locks[n]->Unlock(); } } unsigned long SSL_id_function() { return (unsigned long) LgiGetCurrentThread(); } bool StartSSL(GAutoString &ErrorMsg, SslSocket *sock) { static LMutex Lock; if (Lock.Lock(_FL)) { if (!Library) { Library = new OpenSSL; if (Library && !Library->InitLibrary(sock)) { ErrorMsg = Library->ErrorMsg; DeleteObj(Library); } } Lock.Unlock(); } return Library != NULL; } void EndSSL() { DeleteObj(Library); } -struct SslSocketPriv +struct SslSocketPriv : public LCancel { GCapabilityClient *Caps; bool SslOnConnect; bool IsSSL; bool UseSSLrw; int Timeout; bool RawLFCheck; #ifdef _DEBUG bool LastWasCR; #endif bool IsBlocking; + LCancel *Cancel; // This is just for the UI. GStreamI *Logger; // This is for the connection logging. GAutoString LogFile; GAutoPtr LogStream; int LogFormat; SslSocketPriv() { #ifdef _DEBUG LastWasCR = false; #endif + Cancel = this; Timeout = 20 * 1000; IsSSL = false; UseSSLrw = false; LogFormat = 0; } }; bool SslSocket::DebugLogging = false; SslSocket::SslSocket(GStreamI *logger, GCapabilityClient *caps, bool sslonconnect, bool RawLFCheck) { d = new SslSocketPriv; Bio = 0; Ssl = 0; d->RawLFCheck = RawLFCheck; d->SslOnConnect = sslonconnect; d->Caps = caps; d->Logger = logger; d->IsBlocking = true; GAutoString ErrMsg; if (StartSSL(ErrMsg, this)) { #ifdef WIN32 if (Library->IsOk(this)) { char n[MAX_PATH]; char s[MAX_PATH]; if (GetModuleFileNameA(Library->LibSSL::Handle(), n, sizeof(n))) { sprintf_s(s, sizeof(s), "Using '%s'", n); OnInformation(s); } if (GetModuleFileNameA(Library->LibEAY::Handle(), n, sizeof(n))) { sprintf_s(s, sizeof(s), "Using '%s'", n); OnInformation(s); } } #endif } else if (caps) { caps->NeedsCapability("openssl", ErrMsg); } else { OnError(0, "Can't load or find OpenSSL library."); } } SslSocket::~SslSocket() { Close(); DeleteObj(d); } GStreamI *SslSocket::Clone() { return new SslSocket(d->Logger, d->Caps, true); } +LCancel *SslSocket::GetCancel() +{ + return d->Cancel; +} + +void SslSocket::SetCancel(LCancel *c) +{ + d->Cancel = c; +} + int SslSocket::GetTimeout() { return d->Timeout; } void SslSocket::SetTimeout(int ms) { d->Timeout = ms; } void SslSocket::SetLogger(GStreamI *logger) { d->Logger = logger; } void SslSocket::SetSslOnConnect(bool b) { d->SslOnConnect = b; } GStream *SslSocket::GetLogStream() { if (!d->LogStream && d->LogFile) { if (!d->LogStream.Reset(new GFile)) return NULL; if (!d->LogStream->Open(d->LogFile, O_WRITE)) return NULL; // Seek to the end d->LogStream->SetPos(d->LogStream->GetSize()); } return d->LogStream; } bool SslSocket::GetVariant(const char *Name, GVariant &Val, char *Arr) { if (!Name) return false; if (!_stricmp(Name, "isSsl")) // Type: Bool { Val = true; return true; } return false; } void SslSocket::Log(const char *Str, ssize_t Bytes, SocketMsgType Type) { if (!ValidStr(Str)) return; if (d->Logger) d->Logger->Write(Str, Bytes<0?(int)strlen(Str):Bytes, Type); else if (Type == SocketMsgError) LgiTrace("%.*s", Bytes, Str); } const char *SslSocket::GetErrorString() { return ErrMsg; } void SslSocket::SslError(const char *file, int line, const char *Msg) { char *Part = strrchr((char*)file, DIR_CHAR); #ifndef WIN32 printf("%s:%i - %s\n", file, line, Msg); #endif ErrMsg.Printf("Error: %s:%i - %s\n", Part ? Part + 1 : file, line, Msg); Log(ErrMsg, ErrMsg.Length(), SocketMsgError); } OsSocket SslSocket::Handle(OsSocket Set) { OsSocket h = INVALID_SOCKET; if (Set != INVALID_SOCKET) { long r; bool IsError = false; if (!Ssl) { Ssl = Library->SSL_new(Library->GetServer(this, NULL, NULL)); } if (Ssl) { - r = Library->SSL_set_fd(Ssl, Set); + r = Library->SSL_set_fd(Ssl, (int) Set); Bio = Library->SSL_get_rbio(Ssl); r = Library->SSL_accept(Ssl); if (r <= 0) IsError = true; else if (r == 1) h = Set; } else IsError = true; if (IsError) { long e = Library->ERR_get_error(); char *Msg = Library->ERR_error_string(e, 0); Log(Msg, -1, SocketMsgError); return INVALID_SOCKET; } } else if (Bio) { - uint32 hnd = INVALID_SOCKET; + int hnd = (int)INVALID_SOCKET; Library->BIO_get_fd(Bio, &hnd); h = hnd; } return h; } bool SslSocket::IsOpen() { return Bio != 0; } GString SslGetErrorAsString(OpenSSL *Library) { BIO *bio = Library->BIO_new (Library->BIO_s_mem()); Library->ERR_print_errors (bio); char *buf = NULL; size_t len = Library->BIO_get_mem_data (bio, &buf); GString s(buf, len); Library->BIO_free (bio); return s; } int SslSocket::Open(const char *HostAddr, int Port) { bool Status = false; LMutex::Auto Lck(&Lock, _FL); DebugTrace("%s:%i - SslSocket::Open(%s,%i)\n", _FL, HostAddr, Port); if (Library && Library->IsOk(this) && HostAddr) { char h[256]; sprintf_s(h, sizeof(h), "%s:%i", HostAddr, Port); // Do SSL handshake? if (d->SslOnConnect) { // SSL connection.. d->IsSSL = true; if (Library->Client) { const char *CertDir = "/u/matthew/cert"; - long r = Library->SSL_CTX_load_verify_locations(Library->Client, 0, CertDir); + int r = Library->SSL_CTX_load_verify_locations(Library->Client, 0, CertDir); DebugTrace("%s:%i - SSL_CTX_load_verify_locations=%i\n", _FL, r); if (r > 0) { Bio = Library->BIO_new_ssl_connect(Library->Client); DebugTrace("%s:%i - BIO_new_ssl_connect=%p\n", _FL, Bio); if (Bio) { Library->BIO_get_ssl(Bio, &Ssl); DebugTrace("%s:%i - BIO_get_ssl=%p\n", _FL, Ssl); if (Ssl) { // SNI setup Library->SSL_set_tlsext_host_name(Ssl, HostAddr); // Library->SSL_CTX_set_timeout() Library->BIO_set_conn_hostname(Bio, HostAddr); #if OPENSSL_VERSION_NUMBER < 0x10100000L Library->BIO_set_conn_int_port(Bio, &Port); #else GString sPort; - sPort.Printf("%i"); + sPort.Printf("%i", Port); Library->BIO_set_conn_port(Bio, sPort.Get()); #endif // Do non-block connect uint64 Start = LgiCurrentTime(); int To = GetTimeout(); IsBlocking(false); r = Library->SSL_connect(Ssl); DebugTrace("%s:%i - initial SSL_connect=%i\n", _FL, r); - while (r != 1 && !IsCancelled()) + while (r != 1 && !d->Cancel->IsCancelled()) { - long err = Library->SSL_get_error(Ssl, r); + int err = Library->SSL_get_error(Ssl, r); if (err != SSL_ERROR_WANT_CONNECT) { DebugTrace("%s:%i - SSL_get_error=%i\n", _FL, err); } LgiSleep(50); r = Library->SSL_connect(Ssl); DebugTrace("%s:%i - SSL_connect=%i (%i of %i ms)\n", _FL, r, (int)(LgiCurrentTime() - Start), (int)To); bool TimeOut = !HasntTimedOut(); if (TimeOut) { DebugTrace("%s:%i - SSL connect timeout, to=%i\n", _FL, To); SslError(_FL, "Connection timeout."); break; } } -DebugTrace("%s:%i - open loop finished, r=%i, Cancelled=%i\n", _FL, r, IsCancelled()); +DebugTrace("%s:%i - open loop finished, r=%i, Cancelled=%i\n", _FL, r, d->Cancel->IsCancelled()); if (r == 1) { IsBlocking(true); Library->SSL_set_mode(Ssl, SSL_MODE_AUTO_RETRY); Status = true; // d->UseSSLrw = true; char m[256]; sprintf_s(m, sizeof(m), "Connected to '%s' using SSL", h); OnInformation(m); } else { GString Err = SslGetErrorAsString(Library).Strip(); if (!Err) Err.Printf("BIO_do_connect(%s:%i) failed.", HostAddr, Port); SslError(_FL, Err); } } else SslError(_FL, "BIO_get_ssl failed."); } else SslError(_FL, "BIO_new_ssl_connect failed."); } else SslError(_FL, "SSL_CTX_load_verify_locations failed."); } else SslError(_FL, "No Ctx."); } else { Bio = Library->BIO_new_connect(h); DebugTrace("%s:%i - BIO_new_connect=%p\n", _FL, Bio); if (Bio) { // Non SSL... go into non-blocking mode so that if ::Close() is called we // can quit out of the connect loop. IsBlocking(false); uint64 Start = LgiCurrentTime(); int To = GetTimeout(); long r = Library->BIO_do_connect(Bio); DebugTrace("%s:%i - BIO_do_connect=%i\n", _FL, r); - while (r != 1 && !IsCancelled()) + while (r != 1 && !d->Cancel->IsCancelled()) { if (!Library->BIO_should_retry(Bio)) { break; } LgiSleep(50); r = Library->BIO_do_connect(Bio); DebugTrace("%s:%i - BIO_do_connect=%i\n", _FL, r); if (!HasntTimedOut()) { DebugTrace("%s:%i - open timeout, to=%i\n", _FL, To); OnError(0, "Connection timeout."); break; } } DebugTrace("%s:%i - open loop finished=%i\n", _FL, r); if (r == 1) { IsBlocking(true); Status = true; char m[256]; sprintf_s(m, sizeof(m), "Connected to '%s'", h); OnInformation(m); } else SslError(_FL, "BIO_do_connect failed"); } else SslError(_FL, "BIO_new_connect failed"); } } if (!Status) { Close(); } DebugTrace("%s:%i - SslSocket::Open status=%i\n", _FL, Status); return Status; } bool SslSocket::SetVariant(const char *Name, GVariant &Value, char *Arr) { bool Status = false; if (!Library || !Name) return false; if (!_stricmp(Name, SslSocket_LogFile)) { d->LogFile.Reset(Value.ReleaseStr()); } else if (!_stricmp(Name, SslSocket_LogFormat)) { d->LogFormat = Value.CastInt32(); } else if (!_stricmp(Name, GSocket_Protocol)) { char *v = Value.CastString(); if (v && stristr(v, "SSL")) { if (!Bio) { d->SslOnConnect = true; } else { if (!Library->Client) { SslError(_FL, "Library->Client is null."); } else { Ssl = Library->SSL_new(Library->Client); DebugTrace("%s:%i - SSL_new=%p\n", _FL, Ssl); if (!Ssl) { SslError(_FL, "SSL_new failed."); } else { int r = Library->SSL_set_bio(Ssl, Bio, Bio); DebugTrace("%s:%i - SSL_set_bio=%i\n", _FL, r); uint64 Start = LgiCurrentTime(); int To = GetTimeout(); while (HasntTimedOut()) { r = Library->SSL_connect(Ssl); DebugTrace("%s:%i - SSL_connect=%i\n", _FL, r); if (r < 0) LgiSleep(100); else break; } if (r > 0) { Status = d->UseSSLrw = d->IsSSL = true; OnInformation("Session is now using SSL"); X509 *ServerCert = Library->SSL_get_peer_certificate(Ssl); DebugTrace("%s:%i - SSL_get_peer_certificate=%p\n", _FL, ServerCert); if (ServerCert) { char Txt[256] = ""; Library->X509_NAME_oneline(Library->X509_get_subject_name(ServerCert), Txt, sizeof(Txt)); DebugTrace("%s:%i - X509_NAME_oneline=%s\n", _FL, Txt); OnInformation(Txt); } // SSL_get_verify_result } else { SslError(_FL, "SSL_connect failed."); r = Library->SSL_get_error(Ssl, r); char *Msg = Library->ERR_error_string(r, 0); if (Msg) { OnError(r, Msg); } } } } } } } return Status; } int SslSocket::Close() { - Cancel(true); + d->Cancel->Cancel(true); LMutex::Auto Lck(&Lock, _FL); if (Library) { if (Ssl) { DebugTrace("%s:%i - SSL_shutdown\n", _FL); int r = 0; if ((r = Library->SSL_shutdown(Ssl)) >= 0) { #ifdef WIN32 closesocket #else close #endif (Library->SSL_get_fd(Ssl)); } Library->SSL_free(Ssl); OnInformation("SSL connection closed."); // I think the Ssl object "owns" the Bio object... // So assume it gets fread by SSL_shutdown } else if (Bio) { DebugTrace("%s:%i - BIO_free\n", _FL); Library->BIO_free(Bio); OnInformation("Connection closed."); } Ssl = 0; Bio = 0; } else return false; return true; } bool SslSocket::Listen(int Port) { return false; } bool SslSocket::IsBlocking() { return d->IsBlocking; } void SslSocket::IsBlocking(bool block) { d->IsBlocking = block; if (Bio) { Library->BIO_set_nbio(Bio, !d->IsBlocking); } } bool SslSocket::IsReadable(int TimeoutMs) { // Assign to local var to avoid a thread changing it // on us between the validity check and the select. // Which is important because a socket value of -1 // (ie invalid) will crash the FD_SET macro. OsSocket s = Handle(); if (ValidSocket(s)) { struct timeval t = {TimeoutMs / 1000, (TimeoutMs % 1000) * 1000}; fd_set r; FD_ZERO(&r); FD_SET(s, &r); - int v = select(s+1, &r, 0, 0, &t); + int v = select((int)s+1, &r, 0, 0, &t); if (v > 0 && FD_ISSET(s, &r)) { return true; } else if (v < 0) { // Error(); } } else LgiTrace("%s:%i - Not a valid socket.\n", _FL); return false; } bool SslSocket::IsWritable(int TimeoutMs) { // Assign to local var to avoid a thread changing it // on us between the validity check and the select. // Which is important because a socket value of -1 // (ie invalid) will crash the FD_SET macro. OsSocket s = Handle(); if (ValidSocket(s)) { struct timeval t = {TimeoutMs / 1000, (TimeoutMs % 1000) * 1000}; fd_set w; FD_ZERO(&w); FD_SET(s, &w); - int v = select(s+1, &w, 0, 0, &t); + int v = select((int)s+1, &w, 0, 0, &t); if (v > 0 && FD_ISSET(s, &w)) { return true; } else if (v < 0) { // Error(); } } else LgiTrace("%s:%i - Not a valid socket.\n", _FL); return false; } void SslSocket::OnWrite(const char *Data, ssize_t Len) { #ifdef _DEBUG if (d->RawLFCheck) { const char *End = Data + Len; while (Data < End) { LgiAssert(*Data != '\n' || d->LastWasCR); d->LastWasCR = *Data == '\r'; Data++; } } #endif // Log(Data, Len, SocketMsgSend); } void SslSocket::OnRead(char *Data, ssize_t Len) { #ifdef _DEBUG if (d->RawLFCheck) { const char *End = Data + Len; while (Data < End) { LgiAssert(*Data != '\n' || d->LastWasCR); d->LastWasCR = *Data == '\r'; Data++; } } #endif // Log(Data, Len, SocketMsgReceive); } ssize_t SslSocket::Write(const void *Data, ssize_t Len, int Flags) { LMutex::Auto Lck(&Lock, _FL); if (!Library) { DebugTrace("%s:%i - Library is NULL\n", _FL); return -1; } if (!Bio) { DebugTrace("%s:%i - BIO is NULL\n", _FL); return -1; } - ssize_t r = 0; + int r = 0; if (d->UseSSLrw) { if (Ssl) { uint64 Start = LgiCurrentTime(); int To = GetTimeout(); while (HasntTimedOut()) { - r = Library->SSL_write(Ssl, Data, Len); + r = Library->SSL_write(Ssl, Data, (int)Len); if (r < 0) { LgiSleep(10); } else { DebugTrace("%s:%i - SSL_write(%p,%i)=%i\n", _FL, Data, Len, r); OnWrite((const char*)Data, r); break; } } if (r < 0) { DebugTrace("%s:%i - SSL_write failed (timeout=%i, %ims)\n", _FL, To, (int) (LgiCurrentTime() - Start)); } } else { r = -1; DebugTrace("%s:%i - No SSL\n", _FL); } } else { uint64 Start = LgiCurrentTime(); int To = GetTimeout(); while (HasntTimedOut()) { if (!Library) break; - r = Library->BIO_write(Bio, Data, Len); + r = Library->BIO_write(Bio, Data, (int)Len); DebugTrace("%s:%i - BIO_write(%p,%i)=%i\n", _FL, Data, Len, r); if (r < 0) { LgiSleep(10); } else { OnWrite((const char*)Data, r); break; } } if (r < 0) { DebugTrace("%s:%i - BIO_write failed (timeout=%i, %ims)\n", _FL, To, (int) (LgiCurrentTime() - Start)); } } if (r > 0) { GStream *l = GetLogStream(); if (l) l->Write(Data, r); } if (Ssl) { if (r < 0) { int Err = Library->SSL_get_error(Ssl, r); char Buf[256] = ""; char *e = Library->ERR_error_string(Err, Buf); DebugTrace("%s:%i - ::Write error %i, %s\n", _FL, Err, e); if (e) { OnError(Err, e); } } if (r <= 0) { DebugTrace("%s:%i - ::Write closing %i\n", _FL, r); Close(); } } return r; } ssize_t SslSocket::Read(void *Data, ssize_t Len, int Flags) { LMutex::Auto Lck(&Lock, _FL); if (!Library) return -1; if (Bio) { int r = 0; if (d->UseSSLrw) { if (Ssl) { uint64 Start = LgiCurrentTime(); int To = GetTimeout(); while (HasntTimedOut()) { - r = Library->SSL_read(Ssl, Data, Len); + r = Library->SSL_read(Ssl, Data, (int)Len); DebugTrace("%s:%i - SSL_read(%p,%i)=%i\n", _FL, Data, Len, r); if (r < 0) LgiSleep(10); else { OnRead((char*)Data, r); break; } } } else { r = -1; } } else { uint64 Start = LgiCurrentTime(); int To = GetTimeout(); while (HasntTimedOut()) { - r = Library->BIO_read(Bio, Data, Len); + r = Library->BIO_read(Bio, Data, (int)Len); if (r < 0) { if (d->IsBlocking) LgiSleep(10); else break; } else { DebugTrace("%s:%i - BIO_read(%p,%i)=%i\n", _FL, Data, Len, r); OnRead((char*)Data, r); break; } } } if (r > 0) { GStream *l = GetLogStream(); if (l) l->Write(Data, r); } if (Ssl && d->IsBlocking) { if (r < 0) { int Err = Library->SSL_get_error(Ssl, r); char Buf[256]; char *e = Library->ERR_error_string(Err, Buf); if (e) { OnError(Err, e); } Close(); } if (r <= 0) { Close(); } } return r; } return -1; } void SslSocket::OnError(int ErrorCode, const char *ErrorDescription) { DebugTrace("%s:%i - OnError=%i,%s\n", _FL, ErrorCode, ErrorDescription); GString s; s.Printf("Error %i: %s\n", ErrorCode, ErrorDescription); Log(s, s.Length(), SocketMsgError); } void SslSocket::DebugTrace(const char *fmt, ...) { if (DebugLogging) { char Buffer[512]; va_list Arg; va_start(Arg, fmt); int Ch = vsprintf_s(Buffer, sizeof(Buffer), fmt, Arg); va_end(Arg); if (Ch > 0) { // LgiTrace("SSL:%p: %s", this, Buffer); OnInformation(Buffer); } } } void SslSocket::OnInformation(const char *Str) { while (Str && *Str) { GAutoString a; const char *nl = Str; while (*nl && *nl != '\n') nl++; int Len = (int) (nl - Str + 2); a.Reset(new char[Len]); char *o; for (o = a; Str < nl; Str++) { if (*Str != '\r') *o++ = *Str; } *o++ = '\n'; *o++ = 0; LgiAssert((o-a) <= Len); Log(a, -1, SocketMsgInfo); Str = *nl ? nl + 1 : nl; } }