From 562475aafbaae60f167071005959ecd68df3aec9 Mon Sep 17 00:00:00 2001 From: SuperQ Date: Fri, 17 Jun 2022 20:50:43 +0200 Subject: [PATCH] Refactor packet tracking. Signed-off-by: SuperQ --- packet_tracking.go | 97 ++++++++++++++++++++++++++++++++++++++++++++++ ping.go | 64 ++++++++++++------------------ 2 files changed, 121 insertions(+), 40 deletions(-) create mode 100644 packet_tracking.go diff --git a/packet_tracking.go b/packet_tracking.go new file mode 100644 index 0000000..b3b393b --- /dev/null +++ b/packet_tracking.go @@ -0,0 +1,97 @@ +package probing + +import ( + "sync" + "time" + + "github.com/google/uuid" +) + +type PacketTracker struct { + currentUUID uuid.UUID + packets map[uuid.UUID]PacketSequence + sequence int + nextSequence int + timeout time.Duration + timeoutCh chan *inFlightPacket + + mutex sync.RWMutex +} + +type PacketSequence struct { + packets map[uint]inFlightPacket +} + +type inFlightPacket struct { + timeoutTimer *time.Timer +} + +func NewPacketTracker(t time.Duration) *PacketTracker { + firstUUID := uuid.New() + var firstSequence = map[uuid.UUID]map[int]struct{}{} + firstSequence[firstUUID] = make(map[int]struct{}) + + return &PacketTracker{ + packets: map[uuid.UUID]PacketSequence{}, + sequence: 0, + timeout: t, + } +} + +func (t *PacketTracker) AddPacket() int { + t.mutex.Lock() + defer t.mutex.Unlock() + + if t.nextSequence > 65535 { + newUUID := uuid.New() + t.packets[newUUID] = PacketSequence{} + t.currentUUID = newUUID + t.nextSequence = 0 + } + + t.sequence = t.nextSequence + t.packets[t.currentUUID][t.sequence] = inFlightPacket{} + // if t.timeout > 0 { + // t.packets[t.currentUUID][t.sequence].timeoutTimer = time.Timer(t.timeout) + // } + t.nextSequence++ + return t.sequence +} + +// DeletePacket removes a packet from the tracker. +func (t *PacketTracker) DeletePacket(u uuid.UUID, seq int) { + t.mutex.Lock() + defer t.mutex.Unlock() + + if t.hasPacket(u, seq) { + if t.packets[u][seq] != nil { + t.packets[u][seq].timeoutTimer.Stop() + } + delete(t.packets[u], seq) + } +} + +func (t *PacketTracker) hasPacket(u uuid.UUID, seq int) bool { + _, inflight := t.packets[u][seq] + return inflight +} + +// HasPacket checks the tracker to see if it's currently tracking a packet. +func (t *PacketTracker) HasPacket(u uuid.UUID, seq int) bool { + t.mutex.RLock() + defer t.mutex.Unlock() + + return t.hasPacket(u, seq) +} + +func (t *PacketTracker) HasUUID(u uuid.UUID) bool { + _, hasUUID := t.packets[u] + return hasUUID +} + +func (t *PacketTracker) CurrentUUID() uuid.UUID { + t.mutex.RLock() + defer t.mutex.Unlock() + + return t.currentUUID +} diff --git a/ping.go b/ping.go index ab7a5e0..5327c7e 100644 --- a/ping.go +++ b/ping.go @@ -87,9 +87,6 @@ var ( // New returns a new Pinger struct pointer. func New(addr string) *Pinger { r := rand.New(rand.NewSource(getSeed())) - firstUUID := uuid.New() - var firstSequence = map[uuid.UUID]map[int]struct{}{} - firstSequence[firstUUID] = make(map[int]struct{}) return &Pinger{ Count: -1, Interval: time.Second, @@ -97,17 +94,15 @@ func New(addr string) *Pinger { Size: timeSliceLength + trackerLength, Timeout: time.Duration(math.MaxInt64), - addr: addr, - done: make(chan interface{}), - id: r.Intn(math.MaxUint16), - trackerUUIDs: []uuid.UUID{firstUUID}, - ipaddr: nil, - ipv4: false, - network: "ip", - protocol: "udp", - awaitingSequences: firstSequence, - TTL: 64, - logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())}, + addr: addr, + done: make(chan interface{}), + id: r.Intn(math.MaxUint16), + ipaddr: nil, + ipv4: false, + network: "ip", + protocol: "udp", + TTL: 64, + logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())}, } } @@ -143,6 +138,9 @@ type Pinger struct { // Number of duplicate packets received PacketsRecvDuplicates int + // Per-packet timeout + PacketTimeout time.Duration + // Round trip time statistics minRtt time.Duration maxRtt time.Duration @@ -189,14 +187,11 @@ type Pinger struct { ipaddr *net.IPAddr addr string - // trackerUUIDs is the list of UUIDs being used for sending packets. - trackerUUIDs []uuid.UUID - ipv4 bool id int sequence int - // awaitingSequences are in-flight sequence numbers we keep track of to help remove duplicate receipts - awaitingSequences map[uuid.UUID]map[int]struct{} + // tracker is a PacketTrackrer of UUIDs and sequence numbers. + tracker *PacketTracker // network is one of "ip", "ip4", or "ip6". network string // protocol is "icmp" or "udp". @@ -413,6 +408,9 @@ func (p *Pinger) Run() error { if err != nil { return err } + + p.tracker = NewPacketTracker(p.PacketTimeout) + if conn, err = p.listen(); err != nil { return err } @@ -615,19 +613,12 @@ func (p *Pinger) getPacketUUID(pkt []byte) (*uuid.UUID, error) { return nil, fmt.Errorf("error decoding tracking UUID: %w", err) } - for _, item := range p.trackerUUIDs { - if item == packetUUID { - return &packetUUID, nil - } + if p.tracker.HasUUID(packetUUID) { + return &packetUUID, nil } return nil, nil } -// getCurrentTrackerUUID grabs the latest tracker UUID. -func (p *Pinger) getCurrentTrackerUUID() uuid.UUID { - return p.trackerUUIDs[len(p.trackerUUIDs)-1] -} - func (p *Pinger) processPacket(recv *packet) error { receivedAt := time.Now() var proto int @@ -676,15 +667,15 @@ func (p *Pinger) processPacket(recv *packet) error { inPkt.Rtt = receivedAt.Sub(timestamp) inPkt.Seq = pkt.Seq // If we've already received this sequence, ignore it. - if _, inflight := p.awaitingSequences[*pktUUID][pkt.Seq]; !inflight { + if !p.tracker.HasPacket(*pktUUID, pkt.Seq) { p.PacketsRecvDuplicates++ if p.OnDuplicateRecv != nil { p.OnDuplicateRecv(inPkt) } return nil } - // remove it from the list of sequences we're waiting for so we don't get duplicates. - delete(p.awaitingSequences[*pktUUID], pkt.Seq) + // Remove it from the list of sequences we're waiting for so we don't get duplicates. + p.tracker.DeletePacket(*pktUUID, pkt.Seq) p.updateStatistics(inPkt) default: // Very bad, not sure how this can happen @@ -705,7 +696,7 @@ func (p *Pinger) sendICMP(conn packetConn) error { dst = &net.UDPAddr{IP: p.ipaddr.IP, Zone: p.ipaddr.Zone} } - currentUUID := p.getCurrentTrackerUUID() + currentUUID := p.tracker.CurrentUUID() uuidEncoded, err := currentUUID.MarshalBinary() if err != nil { return fmt.Errorf("unable to marshal UUID binary: %w", err) @@ -753,15 +744,8 @@ func (p *Pinger) sendICMP(conn packetConn) error { handler(outPkt) } // mark this sequence as in-flight - p.awaitingSequences[currentUUID][p.sequence] = struct{}{} + p.sequence = p.tracker.AddPacket() p.PacketsSent++ - p.sequence++ - if p.sequence > 65535 { - newUUID := uuid.New() - p.trackerUUIDs = append(p.trackerUUIDs, newUUID) - p.awaitingSequences[newUUID] = make(map[int]struct{}) - p.sequence = 0 - } break }