diff --git a/connection.cpp b/connection.cpp index 63aac45..3b8b3e7 100644 --- a/connection.cpp +++ b/connection.cpp @@ -165,13 +165,24 @@ void Connection::add (Packet * packet) } } -/* - * finds connection to which this packet belongs. - * a packet belongs to a connection if it matches - * to its reference packet - */ -Connection * findConnection (Packet * packet) -{ +Connection * findConnectionWithMatchingSource(Packet * packet) { + assert(packet->Outgoing()); + + ConnList * current = connections; + while (current != NULL) + { + /* the reference packet is always outgoing */ + if (packet->matchSource(current->getVal()->refpacket)) + { + return current->getVal(); + } + + current = current->getNext(); + } + return NULL; +} + +Connection * findConnectionWithMatchingRefpacketOrSource(Packet * packet) { ConnList * current = connections; while (current != NULL) { @@ -183,25 +194,26 @@ Connection * findConnection (Packet * packet) current = current->getNext(); } + return findConnectionWithMatchingSource(packet); +} - // Try again, now with the packet inverted: - current = connections; - Packet * invertedPacket = packet->newInverted(); - - while (current != NULL) +/* + * finds connection to which this packet belongs. + * a packet belongs to a connection if it matches + * to its reference packet + */ +Connection * findConnection (Packet * packet) +{ + if (packet->Outgoing()) + return findConnectionWithMatchingRefpacketOrSource(packet); + else { - /* the reference packet is always *outgoing* */ - if (invertedPacket->match(current->getVal()->refpacket)) - { - delete invertedPacket; - return current->getVal(); - } + Packet * invertedPacket = packet->newInverted(); + Connection * result = findConnectionWithMatchingRefpacketOrSource(invertedPacket); - current = current->getNext(); + delete invertedPacket; + return result; } - - delete invertedPacket; - return NULL; } /* diff --git a/packet.cpp b/packet.cpp index c961777..9f5a436 100644 --- a/packet.cpp +++ b/packet.cpp @@ -231,11 +231,12 @@ bool Packet::Outgoing () { dir = dir_outgoing; return true; } else { - /*if (DEBUG) { + if (DEBUG) { if (sa_family == AF_INET) islocal = local_addrs->contains(dip.s_addr); else islocal = local_addrs->contains(dip6); + if (!islocal) { std::cerr << "Neither dip nor sip are local: "; char addy [50]; @@ -246,7 +247,7 @@ bool Packet::Outgoing () { return false; } - }*/ + } dir = dir_incoming; return false; } @@ -273,7 +274,7 @@ char * Packet::gethashstring () inet_ntop(sa_family, &dip, remote_string, 49); } else { inet_ntop(sa_family, &sip6, local_string, 49); - inet_ntop(sa_family, &dip6, remote_string, 49); +inet_ntop(sa_family, &dip6, remote_string, 49); } if (Outgoing()) { snprintf(hashstring, HASHKEYSIZE * sizeof(char), "%s:%d-%s:%d", local_string, sport, remote_string, dport); @@ -294,3 +295,8 @@ bool Packet::match (Packet * other) return (sport == other->sport) && (dport == other->dport) && (sameinaddr(sip, other->sip)) && (sameinaddr(dip, other->dip)); } + +bool Packet::matchSource (Packet * other) +{ + return (sport == other->sport) && (sameinaddr(sip, other->sip)); +} diff --git a/packet.h b/packet.h index a993c20..7b28eec 100644 --- a/packet.h +++ b/packet.h @@ -74,6 +74,7 @@ public: bool Outgoing (); bool match (Packet * other); + bool matchSource (Packet * other); /* returns '1.2.3.4:5-1.2.3.4:6'-style string */ char * gethashstring(); private: