From 5e9264130371915610dfc29ad78e9a2cbc3158d0 Mon Sep 17 00:00:00 2001 From: SuperQ Date: Fri, 17 Jun 2022 20:50:43 +0200 Subject: [PATCH] Refactor packet tracking. Signed-off-by: SuperQ --- packet_tracking.go | 114 +++++++++++++++++++++++++++++++++++++++++++++ ping.go | 64 ++++++++++--------------- ping_test.go | 19 ++++---- 3 files changed, 148 insertions(+), 49 deletions(-) create mode 100644 packet_tracking.go diff --git a/packet_tracking.go b/packet_tracking.go new file mode 100644 index 0000000..69b3deb --- /dev/null +++ b/packet_tracking.go @@ -0,0 +1,114 @@ +package probing + +import ( + "sync" + "time" + + "github.com/google/uuid" +) + +type PacketTracker struct { + currentUUID uuid.UUID + packets map[uuid.UUID]PacketSequence + sequence int + nextSequence int + timeout time.Duration + timeoutCh chan *inFlightPacket + + mutex sync.RWMutex +} + +type PacketSequence struct { + packets map[int]inFlightPacket +} + +func (ps PacketSequence) NewInflightPacket(sequence int) { + ps.packets[sequence] = inFlightPacket{} +} + +func (ps PacketSequence) GetPacket(sequence int) (inFlightPacket, bool) { + packet, ok := ps.packets[sequence] + return packet, ok +} + +func (ps PacketSequence) RemovePacket(sequence int) { + delete(ps.packets, sequence) +} + +type inFlightPacket struct { + timeoutTimer *time.Timer +} + +func newPacketTracker(t time.Duration) *PacketTracker { + firstUUID := uuid.New() + var firstSequence = map[uuid.UUID]map[int]struct{}{} + firstSequence[firstUUID] = make(map[int]struct{}) + + return &PacketTracker{ + packets: map[uuid.UUID]PacketSequence{}, + sequence: 0, + timeout: t, + } +} + +func (t *PacketTracker) AddPacket() int { + t.mutex.Lock() + defer t.mutex.Unlock() + + if t.nextSequence > 65535 { + newUUID := uuid.New() + t.packets[newUUID] = PacketSequence{} + t.currentUUID = newUUID + t.nextSequence = 0 + } + + t.sequence = t.nextSequence + t.packets[t.currentUUID].NewInflightPacket(t.sequence) + // if t.timeout > 0 { + // t.packets[t.currentUUID][t.sequence].timeoutTimer = time.Timer(t.timeout) + // } + t.nextSequence++ + return t.sequence +} + +// DeletePacket removes a packet from the tracker. +func (t *PacketTracker) DeletePacket(u uuid.UUID, seq int) { + t.mutex.Lock() + defer t.mutex.Unlock() + + if t.hasPacket(u, seq) { + // if _, ok := t.packets[u].GetPacket(seq) ; ok != nil { + // t.packets[u][seq].timeoutTimer.Stop() + // } + t.packets[u].RemovePacket(seq) + } +} + +func (t *PacketTracker) hasPacket(u uuid.UUID, seq int) bool { + inflight, ok := t.packets[u] + if ok == false { + return ok + } + _, ok = inflight.GetPacket(seq) + return ok +} + +// HasPacket checks the tracker to see if it's currently tracking a packet. +func (t *PacketTracker) HasPacket(u uuid.UUID, seq int) bool { + t.mutex.RLock() + defer t.mutex.Unlock() + + return t.hasPacket(u, seq) +} + +func (t *PacketTracker) HasUUID(u uuid.UUID) bool { + _, hasUUID := t.packets[u] + return hasUUID +} + +func (t *PacketTracker) CurrentUUID() uuid.UUID { + // t.mutex.RLock() + // defer t.mutex.Unlock() + + return t.currentUUID +} diff --git a/ping.go b/ping.go index ece069b..5ed02e8 100644 --- a/ping.go +++ b/ping.go @@ -86,9 +86,6 @@ var ( // New returns a new Pinger struct pointer. func New(addr string) *Pinger { r := rand.New(rand.NewSource(getSeed())) - firstUUID := uuid.New() - var firstSequence = map[uuid.UUID]map[int]struct{}{} - firstSequence[firstUUID] = make(map[int]struct{}) return &Pinger{ Count: -1, Interval: time.Second, @@ -96,17 +93,15 @@ func New(addr string) *Pinger { Size: timeSliceLength + trackerLength, Timeout: time.Duration(math.MaxInt64), - addr: addr, - done: make(chan interface{}), - id: r.Intn(math.MaxUint16), - trackerUUIDs: []uuid.UUID{firstUUID}, - ipaddr: nil, - ipv4: false, - network: "ip", - protocol: "udp", - awaitingSequences: firstSequence, - TTL: 64, - logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())}, + addr: addr, + done: make(chan interface{}), + id: r.Intn(math.MaxUint16), + ipaddr: nil, + ipv4: false, + network: "ip", + protocol: "udp", + TTL: 64, + logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())}, } } @@ -142,6 +137,9 @@ type Pinger struct { // Number of duplicate packets received PacketsRecvDuplicates int + // Per-packet timeout + PacketTimeout time.Duration + // Round trip time statistics minRtt time.Duration maxRtt time.Duration @@ -188,14 +186,11 @@ type Pinger struct { ipaddr *net.IPAddr addr string - // trackerUUIDs is the list of UUIDs being used for sending packets. - trackerUUIDs []uuid.UUID - ipv4 bool id int sequence int - // awaitingSequences are in-flight sequence numbers we keep track of to help remove duplicate receipts - awaitingSequences map[uuid.UUID]map[int]struct{} + // tracker is a PacketTrackrer of UUIDs and sequence numbers. + tracker *PacketTracker // network is one of "ip", "ip4", or "ip6". network string // protocol is "icmp" or "udp". @@ -412,6 +407,9 @@ func (p *Pinger) Run() error { if err != nil { return err } + + p.tracker = newPacketTracker(p.PacketTimeout) + if conn, err = p.listen(); err != nil { return err } @@ -614,19 +612,12 @@ func (p *Pinger) getPacketUUID(pkt []byte) (*uuid.UUID, error) { return nil, fmt.Errorf("error decoding tracking UUID: %w", err) } - for _, item := range p.trackerUUIDs { - if item == packetUUID { - return &packetUUID, nil - } + if p.tracker.HasUUID(packetUUID) { + return &packetUUID, nil } return nil, nil } -// getCurrentTrackerUUID grabs the latest tracker UUID. -func (p *Pinger) getCurrentTrackerUUID() uuid.UUID { - return p.trackerUUIDs[len(p.trackerUUIDs)-1] -} - func (p *Pinger) processPacket(recv *packet) error { receivedAt := time.Now() var proto int @@ -675,15 +666,15 @@ func (p *Pinger) processPacket(recv *packet) error { inPkt.Rtt = receivedAt.Sub(timestamp) inPkt.Seq = pkt.Seq // If we've already received this sequence, ignore it. - if _, inflight := p.awaitingSequences[*pktUUID][pkt.Seq]; !inflight { + if !p.tracker.HasPacket(*pktUUID, pkt.Seq) { p.PacketsRecvDuplicates++ if p.OnDuplicateRecv != nil { p.OnDuplicateRecv(inPkt) } return nil } - // remove it from the list of sequences we're waiting for so we don't get duplicates. - delete(p.awaitingSequences[*pktUUID], pkt.Seq) + // Remove it from the list of sequences we're waiting for so we don't get duplicates. + p.tracker.DeletePacket(*pktUUID, pkt.Seq) p.updateStatistics(inPkt) default: // Very bad, not sure how this can happen @@ -704,7 +695,7 @@ func (p *Pinger) sendICMP(conn packetConn) error { dst = &net.UDPAddr{IP: p.ipaddr.IP, Zone: p.ipaddr.Zone} } - currentUUID := p.getCurrentTrackerUUID() + currentUUID := p.tracker.CurrentUUID() uuidEncoded, err := currentUUID.MarshalBinary() if err != nil { return fmt.Errorf("unable to marshal UUID binary: %w", err) @@ -752,15 +743,8 @@ func (p *Pinger) sendICMP(conn packetConn) error { handler(outPkt) } // mark this sequence as in-flight - p.awaitingSequences[currentUUID][p.sequence] = struct{}{} + p.sequence = p.tracker.AddPacket() p.PacketsSent++ - p.sequence++ - if p.sequence > 65535 { - newUUID := uuid.New() - p.trackerUUIDs = append(p.trackerUUIDs, newUUID) - p.awaitingSequences[newUUID] = make(map[int]struct{}) - p.sequence = 0 - } break } diff --git a/ping_test.go b/ping_test.go index 87fe448..9ded067 100644 --- a/ping_test.go +++ b/ping_test.go @@ -22,7 +22,7 @@ func TestProcessPacket(t *testing.T) { shouldBe1++ } - currentUUID := pinger.getCurrentTrackerUUID() + currentUUID := pinger.tracker.CurrentUUID() uuidEncoded, err := currentUUID.MarshalBinary() if err != nil { t.Fatalf("unable to marshal UUID binary: %s", err) @@ -37,7 +37,7 @@ func TestProcessPacket(t *testing.T) { Seq: pinger.sequence, Data: data, } - pinger.awaitingSequences[currentUUID][pinger.sequence] = struct{}{} + pinger.tracker.AddPacket() msg := &icmp.Message{ Type: ipv4.ICMPTypeEchoReply, @@ -66,7 +66,7 @@ func TestProcessPacket_IgnoreNonEchoReplies(t *testing.T) { shouldBe0++ } - currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary() + currentUUID, err := pinger.tracker.CurrentUUID().MarshalBinary() if err != nil { t.Fatalf("unable to marshal UUID binary: %s", err) } @@ -109,7 +109,7 @@ func TestProcessPacket_IDMismatch(t *testing.T) { shouldBe0++ } - currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary() + currentUUID, err := pinger.tracker.CurrentUUID().MarshalBinary() if err != nil { t.Fatalf("unable to marshal UUID binary: %s", err) } @@ -189,7 +189,7 @@ func TestProcessPacket_LargePacket(t *testing.T) { pinger := makeTestPinger() pinger.Size = 4096 - currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary() + currentUUID, err := pinger.tracker.CurrentUUID().MarshalBinary() if err != nil { t.Fatalf("unable to marshal UUID binary: %s", err) } @@ -484,6 +484,7 @@ func makeTestPinger() *Pinger { pinger.protocol = "icmp" pinger.id = 123 pinger.Size = 0 + pinger.tracker = newPacketTracker(time.Second * 5) return pinger } @@ -542,7 +543,7 @@ func BenchmarkProcessPacket(b *testing.B) { pinger.protocol = "ip4:icmp" pinger.id = 123 - currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary() + currentUUID, err := pinger.tracker.CurrentUUID().MarshalBinary() if err != nil { b.Fatalf("unable to marshal UUID binary: %s", err) } @@ -591,7 +592,7 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) { dups++ } - currentUUID := pinger.getCurrentTrackerUUID() + currentUUID := pinger.tracker.CurrentUUID() uuidEncoded, err := currentUUID.MarshalBinary() if err != nil { t.Fatalf("unable to marshal UUID binary: %s", err) @@ -606,8 +607,8 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) { Seq: 0, Data: data, } - // register the sequence as sent - pinger.awaitingSequences[currentUUID][0] = struct{}{} + // Register the sequence as sent. + pinger.tracker.AddPacket() msg := &icmp.Message{ Type: ipv4.ICMPTypeEchoReply,