diff --git a/src/connection.cpp b/src/connection.cpp index 9292a18..0959fa2 100644 --- a/src/connection.cpp +++ b/src/connection.cpp @@ -33,7 +33,7 @@ #include "nethogs.h" #include "process.h" -ConnList *connections = NULL; +ConnList connections; extern Process *unknownudp; void PackList::add(Packet *p) { @@ -78,7 +78,6 @@ u_int64_t PackList::sumanddel(timeval t) { /* packet may be deleted by caller */ Connection::Connection(Packet *packet) { assert(packet != NULL); - connections = new ConnList(this, connections); sent_packets = new PackList(); recv_packets = new PackList(); sumSent = 0; @@ -96,6 +95,7 @@ Connection::Connection(Packet *packet) { recv_packets->add(packet); refpacket = packet->newInverted(); } + connections.insert(this); lastpacket = packet->time.tv_sec; if (DEBUG) std::cout << "New reference packet created at " << refpacket << std::endl; @@ -104,6 +104,13 @@ Connection::Connection(Packet *packet) { Connection::~Connection() { if (DEBUG) std::cout << "Deleting connection" << std::endl; + auto r = connections.equal_range(this); + for (auto it = r.first; it != r.second; ++it) { + if (*it == this) { + connections.erase(it); + break; + } + } /* refpacket is not a pointer to one of the packets in the lists * so deleted */ delete (refpacket); @@ -111,24 +118,6 @@ Connection::~Connection() { delete sent_packets; if (recv_packets != NULL) delete recv_packets; - - ConnList *curr_conn = connections; - ConnList *prev_conn = NULL; - while (curr_conn != NULL) { - if (curr_conn->getVal() == this) { - ConnList *todelete = curr_conn; - curr_conn = curr_conn->getNext(); - if (prev_conn == NULL) { - connections = curr_conn; - } else { - prev_conn->setNext(curr_conn); - } - delete (todelete); - } else { - prev_conn = curr_conn; - curr_conn = curr_conn->getNext(); - } - } } /* the packet will be freed by the calling code */ @@ -156,26 +145,24 @@ Connection *findConnectionWithMatchingSource(Packet *packet, short int packettype) { assert(packet->Outgoing()); - ConnList *current = NULL; + ConnList *connList = NULL; switch (packettype) { case IPPROTO_TCP: { - current = connections; + connList = &connections; break; } case IPPROTO_UDP: { - current = unknownudp->connections; + connList = &unknownudp->connections; break; } } - while (current != NULL) { - /* the reference packet is always outgoing */ - if (packet->matchSource(current->getVal()->refpacket)) { - return current->getVal(); - } - - current = current->getNext(); + Packet p = packet->onlySource(); + auto it = connList->lower_bound(&p); + /* the reference packet is always outgoing */ + if (it != connList->end() && packet->matchSource((*it)->refpacket)) { + return *it; } return NULL; @@ -184,25 +171,23 @@ Connection *findConnectionWithMatchingSource(Packet *packet, Connection *findConnectionWithMatchingRefpacketOrSource(Packet *packet, short int packettype) { - ConnList *current = NULL; + ConnList *connList = NULL; switch (packettype) { case IPPROTO_TCP: { - current = connections; + connList = &connections; break; } case IPPROTO_UDP: { - current = unknownudp->connections; + connList = &unknownudp->connections; break; } } - while (current != NULL) { - /* the reference packet is always *outgoing* */ - if (packet->match(current->getVal()->refpacket)) { - return current->getVal(); - } - current = current->getNext(); + auto it = connList->lower_bound(packet); + /* the reference packet is always *outgoing* */ + if (it != connList->end() && packet->match((*it)->refpacket)) { + return *it; } return findConnectionWithMatchingSource(packet, packettype); diff --git a/src/cui.cpp b/src/cui.cpp index 29cfeb5..edfaef0 100644 --- a/src/cui.cpp +++ b/src/cui.cpp @@ -332,13 +332,10 @@ void show_trace(Line *lines[], int nproc) { } /* print the 'unknown' connections, for debugging */ - ConnList *curr_unknownconn = unknowntcp->connections; - while (curr_unknownconn != NULL) { + for (auto it = unknowntcp->connections.begin(); it != unknowntcp->connections.end(); ++it) { std::cout << "Unknown connection: " - << curr_unknownconn->getVal()->refpacket->gethashstring() + << (*it)->refpacket->gethashstring() << std::endl; - - curr_unknownconn = curr_unknownconn->getNext(); } } diff --git a/src/packet.cpp b/src/packet.cpp index 66ea101..34b8791 100644 --- a/src/packet.cpp +++ b/src/packet.cpp @@ -289,7 +289,7 @@ char *Packet::gethashstring() { /* 2 packets match if they have the same * source and destination ports and IP's. */ -bool Packet::match(Packet *other) { +bool Packet::match(const Packet *other) const { return sa_family == other->sa_family && (sport == other->sport) && (dport == other->dport) && (sa_family == AF_INET @@ -298,9 +298,46 @@ bool Packet::match(Packet *other) { (samein6addr(dip6, other->dip6))); } -bool Packet::matchSource(Packet *other) { +bool Packet::matchSource(const Packet *other) const { return sa_family == other->sa_family && (sport == other->sport) && (sa_family == AF_INET ? (sameinaddr(sip, other->sip)) : (samein6addr(sip6, other->sip6))); } + +Packet Packet::onlySource() const { + Packet p = *this; + std::fill(std::begin(p.dip6.s6_addr), std::end(p.dip6.s6_addr), 0); + p.dip.s_addr = 0; + p.dport = 0; + return p; +} + +bool Packet::operator< (const Packet& other) const { + if (sa_family != other.sa_family) + return dir < other.sa_family; + /* source address first */ + if (sport != other.sport) + return sport < other.sport; + if (sa_family == AF_INET) { + if (sip.s_addr != other.sip.s_addr) + return sip.s_addr < other.sip.s_addr; + } else { + for (int i = 0; i < 16; i++) + if (sip6.s6_addr[i] != other.sip6.s6_addr[i]) + return sip6.s6_addr[i] < other.sip6.s6_addr[i]; + } + /* destination address second */ + if (dport != other.dport) + return dport < other.dport; + if (sa_family == AF_INET) { + if (dip.s_addr != other.dip.s_addr) + return dip.s_addr < other.dip.s_addr; + } else { + for (int i = 0; i < 16; i++) + if (dip6.s6_addr[i] != other.dip6.s6_addr[i]) + return dip6.s6_addr[i] < other.dip6.s6_addr[i]; + } + /* equal */ + return false; +} diff --git a/src/packet.h b/src/packet.h index d158dba..0fde51b 100644 --- a/src/packet.h +++ b/src/packet.h @@ -70,8 +70,11 @@ public: /* is this packet coming from the local host? */ bool Outgoing(); - bool match(Packet *other); - bool matchSource(Packet *other); + bool match(const Packet *other) const; + bool matchSource(const Packet *other) const; + /* returns a copy with destination information stripped (for comparisons) */ + Packet onlySource() const; + bool operator< (const Packet& other) const; /* returns '1.2.3.4:5-1.2.3.4:6'-style string */ char *gethashstring(); diff --git a/src/process.cpp b/src/process.cpp index 321f1f7..1cfc741 100644 --- a/src/process.cpp +++ b/src/process.cpp @@ -97,13 +97,10 @@ void process_init() { int Process::getLastPacket() { int lastpacket = 0; - ConnList *curconn = connections; - while (curconn != NULL) { - assert(curconn != NULL); - assert(curconn->getVal() != NULL); - if (curconn->getVal()->getLastPacket() > lastpacket) - lastpacket = curconn->getVal()->getLastPacket(); - curconn = curconn->getNext(); + for (auto it = connections.begin(); it != connections.end(); ++it) { + assert(*it != NULL); + if ((*it)->getLastPacket() > lastpacket) + lastpacket = (*it)->getLastPacket(); } return lastpacket; } @@ -113,30 +110,20 @@ static void sum_active_connections(Process *process_ptr, u_int64_t &sum_sent, u_int64_t &sum_recv) { /* walk though all process_ptr process's connections, and sum * them up */ - ConnList *curconn = process_ptr->connections; - ConnList *previous = NULL; - while (curconn != NULL) { - if (curconn->getVal()->getLastPacket() <= curtime.tv_sec - CONNTIMEOUT) { + for (auto it = process_ptr->connections.begin(); it != process_ptr->connections.end(); ) { + if ((*it)->getLastPacket() <= curtime.tv_sec - CONNTIMEOUT) { /* capture sent and received totals before deleting */ - process_ptr->sent_by_closed_bytes += curconn->getVal()->sumSent; - process_ptr->rcvd_by_closed_bytes += curconn->getVal()->sumRecv; + process_ptr->sent_by_closed_bytes += (*it)->sumSent; + process_ptr->rcvd_by_closed_bytes += (*it)->sumRecv; /* stalled connection, remove. */ - ConnList *todelete = curconn; - Connection *conn_todelete = curconn->getVal(); - curconn = curconn->getNext(); - if (todelete == process_ptr->connections) - process_ptr->connections = curconn; - if (previous != NULL) - previous->setNext(curconn); - delete (todelete); - delete (conn_todelete); + delete (*it); + it = process_ptr->connections.erase(it); } else { u_int64_t sent = 0, recv = 0; - curconn->getVal()->sumanddel(curtime, &recv, &sent); + (*it)->sumanddel(curtime, &recv, &sent); sum_sent += sent; sum_recv += recv; - previous = curconn; - curconn = curconn->getNext(); + ++it; } } } @@ -171,12 +158,10 @@ void Process::getgbps(float *recvd, float *sent) { /** get total values for this process */ void Process::gettotal(u_int64_t *recvd, u_int64_t *sent) { u_int64_t sum_sent = 0, sum_recv = 0; - ConnList *curconn = this->connections; - while (curconn != NULL) { - Connection *conn = curconn->getVal(); + for (auto it = this->connections.begin(); it != this->connections.end(); ++it) { + Connection *conn = (*it); sum_sent += conn->sumSent; sum_recv += conn->sumRecv; - curconn = curconn->getNext(); } // std::cout << "Sum sent: " << sum_sent << std::endl; // std::cout << "Sum recv: " << sum_recv << std::endl; @@ -403,7 +388,7 @@ Process *getProcess(Connection *connection, const char *devicename, processes = new ProcList(proc, processes); } - proc->connections = new ConnList(connection, proc->connections); + proc->connections.insert(connection); return proc; } diff --git a/src/process.h b/src/process.h index 577146e..3730e33 100644 --- a/src/process.h +++ b/src/process.h @@ -26,32 +26,30 @@ #include "connection.h" #include "nethogs.h" #include +#include extern bool tracemode; extern bool bughuntmode; void check_all_procs(); -class ConnList { -public: - ConnList(Connection *m_val, ConnList *m_next) { - assert(m_val != NULL); - val = m_val; - next = m_next; +/* compares Connection pointers by their refpacket */ +struct ConnectionComparator { + using is_transparent = void; + bool operator()(const Connection* l, const Connection* r) const { + return *l->refpacket < *r->refpacket; } - ~ConnList() { - /* does not delete its value, to allow a connection to - * remove itself from the global connlist in its destructor */ + bool operator()(const Packet* l, const Connection* r) const { + return *l < *r->refpacket; + } + bool operator()(const Connection* l, const Packet* r) const { + return *l->refpacket < *r; } - Connection *getVal() { return val; } - void setNext(ConnList *m_next) { next = m_next; } - ConnList *getNext() { return next; } - -private: - Connection *val; - ConnList *next; }; +/* ordered set of Connection pointers */ +typedef std::multiset ConnList; + class Process { public: /* the process makes a copy of the name. the device name needs to be stable. @@ -75,7 +73,6 @@ public: cmdline = strdup(m_cmdline); devicename = m_devicename; - connections = NULL; pid = 0; uid = 0; sent_by_closed_bytes = 0; @@ -106,7 +103,7 @@ public: u_int64_t sent_by_closed_bytes; u_int64_t rcvd_by_closed_bytes; - ConnList *connections; + ConnList connections; uid_t getUid() { return uid; } void setUid(uid_t m_uid) { uid = m_uid; }