Skip to content

Commit

Permalink
Refactor packet tracking.
Browse files Browse the repository at this point in the history
Signed-off-by: SuperQ <[email protected]>
  • Loading branch information
SuperQ committed Jun 18, 2022
1 parent 0caa487 commit 562475a
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 40 deletions.
97 changes: 97 additions & 0 deletions packet_tracking.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package probing

import (
"sync"
"time"

"github.com/google/uuid"
)

type PacketTracker struct {
currentUUID uuid.UUID
packets map[uuid.UUID]PacketSequence
sequence int
nextSequence int
timeout time.Duration
timeoutCh chan *inFlightPacket

mutex sync.RWMutex
}

type PacketSequence struct {
packets map[uint]inFlightPacket
}

type inFlightPacket struct {
timeoutTimer *time.Timer
}

func NewPacketTracker(t time.Duration) *PacketTracker {
firstUUID := uuid.New()
var firstSequence = map[uuid.UUID]map[int]struct{}{}
firstSequence[firstUUID] = make(map[int]struct{})

return &PacketTracker{
packets: map[uuid.UUID]PacketSequence{},
sequence: 0,
timeout: t,
}
}

func (t *PacketTracker) AddPacket() int {
t.mutex.Lock()
defer t.mutex.Unlock()

if t.nextSequence > 65535 {
newUUID := uuid.New()
t.packets[newUUID] = PacketSequence{}
t.currentUUID = newUUID
t.nextSequence = 0
}

t.sequence = t.nextSequence
t.packets[t.currentUUID][t.sequence] = inFlightPacket{}
// if t.timeout > 0 {
// t.packets[t.currentUUID][t.sequence].timeoutTimer = time.Timer(t.timeout)
// }
t.nextSequence++
return t.sequence
}

// DeletePacket removes a packet from the tracker.
func (t *PacketTracker) DeletePacket(u uuid.UUID, seq int) {
t.mutex.Lock()
defer t.mutex.Unlock()

if t.hasPacket(u, seq) {
if t.packets[u][seq] != nil {
t.packets[u][seq].timeoutTimer.Stop()
}
delete(t.packets[u], seq)
}
}

func (t *PacketTracker) hasPacket(u uuid.UUID, seq int) bool {
_, inflight := t.packets[u][seq]
return inflight
}

// HasPacket checks the tracker to see if it's currently tracking a packet.
func (t *PacketTracker) HasPacket(u uuid.UUID, seq int) bool {
t.mutex.RLock()
defer t.mutex.Unlock()

return t.hasPacket(u, seq)
}

func (t *PacketTracker) HasUUID(u uuid.UUID) bool {
_, hasUUID := t.packets[u]
return hasUUID
}

func (t *PacketTracker) CurrentUUID() uuid.UUID {
t.mutex.RLock()
defer t.mutex.Unlock()

return t.currentUUID
}
64 changes: 24 additions & 40 deletions ping.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,27 +87,22 @@ var (
// New returns a new Pinger struct pointer.
func New(addr string) *Pinger {
r := rand.New(rand.NewSource(getSeed()))
firstUUID := uuid.New()
var firstSequence = map[uuid.UUID]map[int]struct{}{}
firstSequence[firstUUID] = make(map[int]struct{})
return &Pinger{
Count: -1,
Interval: time.Second,
RecordRtts: true,
Size: timeSliceLength + trackerLength,
Timeout: time.Duration(math.MaxInt64),

addr: addr,
done: make(chan interface{}),
id: r.Intn(math.MaxUint16),
trackerUUIDs: []uuid.UUID{firstUUID},
ipaddr: nil,
ipv4: false,
network: "ip",
protocol: "udp",
awaitingSequences: firstSequence,
TTL: 64,
logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())},
addr: addr,
done: make(chan interface{}),
id: r.Intn(math.MaxUint16),
ipaddr: nil,
ipv4: false,
network: "ip",
protocol: "udp",
TTL: 64,
logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())},
}
}

Expand Down Expand Up @@ -143,6 +138,9 @@ type Pinger struct {
// Number of duplicate packets received
PacketsRecvDuplicates int

// Per-packet timeout
PacketTimeout time.Duration

// Round trip time statistics
minRtt time.Duration
maxRtt time.Duration
Expand Down Expand Up @@ -189,14 +187,11 @@ type Pinger struct {
ipaddr *net.IPAddr
addr string

// trackerUUIDs is the list of UUIDs being used for sending packets.
trackerUUIDs []uuid.UUID

ipv4 bool
id int
sequence int
// awaitingSequences are in-flight sequence numbers we keep track of to help remove duplicate receipts
awaitingSequences map[uuid.UUID]map[int]struct{}
// tracker is a PacketTrackrer of UUIDs and sequence numbers.
tracker *PacketTracker
// network is one of "ip", "ip4", or "ip6".
network string
// protocol is "icmp" or "udp".
Expand Down Expand Up @@ -413,6 +408,9 @@ func (p *Pinger) Run() error {
if err != nil {
return err
}

p.tracker = NewPacketTracker(p.PacketTimeout)

if conn, err = p.listen(); err != nil {
return err
}
Expand Down Expand Up @@ -615,19 +613,12 @@ func (p *Pinger) getPacketUUID(pkt []byte) (*uuid.UUID, error) {
return nil, fmt.Errorf("error decoding tracking UUID: %w", err)
}

for _, item := range p.trackerUUIDs {
if item == packetUUID {
return &packetUUID, nil
}
if p.tracker.HasUUID(packetUUID) {
return &packetUUID, nil
}
return nil, nil
}

// getCurrentTrackerUUID grabs the latest tracker UUID.
func (p *Pinger) getCurrentTrackerUUID() uuid.UUID {
return p.trackerUUIDs[len(p.trackerUUIDs)-1]
}

func (p *Pinger) processPacket(recv *packet) error {
receivedAt := time.Now()
var proto int
Expand Down Expand Up @@ -676,15 +667,15 @@ func (p *Pinger) processPacket(recv *packet) error {
inPkt.Rtt = receivedAt.Sub(timestamp)
inPkt.Seq = pkt.Seq
// If we've already received this sequence, ignore it.
if _, inflight := p.awaitingSequences[*pktUUID][pkt.Seq]; !inflight {
if !p.tracker.HasPacket(*pktUUID, pkt.Seq) {
p.PacketsRecvDuplicates++
if p.OnDuplicateRecv != nil {
p.OnDuplicateRecv(inPkt)
}
return nil
}
// remove it from the list of sequences we're waiting for so we don't get duplicates.
delete(p.awaitingSequences[*pktUUID], pkt.Seq)
// Remove it from the list of sequences we're waiting for so we don't get duplicates.
p.tracker.DeletePacket(*pktUUID, pkt.Seq)
p.updateStatistics(inPkt)
default:
// Very bad, not sure how this can happen
Expand All @@ -705,7 +696,7 @@ func (p *Pinger) sendICMP(conn packetConn) error {
dst = &net.UDPAddr{IP: p.ipaddr.IP, Zone: p.ipaddr.Zone}
}

currentUUID := p.getCurrentTrackerUUID()
currentUUID := p.tracker.CurrentUUID()
uuidEncoded, err := currentUUID.MarshalBinary()
if err != nil {
return fmt.Errorf("unable to marshal UUID binary: %w", err)
Expand Down Expand Up @@ -753,15 +744,8 @@ func (p *Pinger) sendICMP(conn packetConn) error {
handler(outPkt)
}
// mark this sequence as in-flight
p.awaitingSequences[currentUUID][p.sequence] = struct{}{}
p.sequence = p.tracker.AddPacket()
p.PacketsSent++
p.sequence++
if p.sequence > 65535 {
newUUID := uuid.New()
p.trackerUUIDs = append(p.trackerUUIDs, newUUID)
p.awaitingSequences[newUUID] = make(map[int]struct{})
p.sequence = 0
}
break
}

Expand Down

0 comments on commit 562475a

Please sign in to comment.