Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor packet tracking. #9

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions packet_tracking.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package probing

import (
"sync"
"time"

"github.com/google/uuid"
)

type PacketTracker struct {
currentUUID uuid.UUID
packets map[uuid.UUID]PacketSequence
sequence int
nextSequence int
timeout time.Duration
timeoutCh chan *inFlightPacket

mutex sync.RWMutex
}

type PacketSequence struct {
packets map[int]inFlightPacket
}

func (ps PacketSequence) NewInflightPacket(sequence int) {
ps.packets[sequence] = inFlightPacket{}
}

func (ps PacketSequence) GetPacket(sequence int) (inFlightPacket, bool) {
packet, ok := ps.packets[sequence]
return packet, ok
}

func (ps PacketSequence) RemovePacket(sequence int) {
delete(ps.packets, sequence)
}

type inFlightPacket struct {
timeoutTimer *time.Timer
}

func newPacketTracker(t time.Duration) *PacketTracker {
firstUUID := uuid.New()
var firstSequence = map[uuid.UUID]map[int]struct{}{}
firstSequence[firstUUID] = make(map[int]struct{})

return &PacketTracker{
packets: map[uuid.UUID]PacketSequence{},
sequence: 0,
timeout: t,
}
}

func (t *PacketTracker) AddPacket() int {
t.mutex.Lock()
defer t.mutex.Unlock()

if t.nextSequence > 65535 {
newUUID := uuid.New()
t.packets[newUUID] = PacketSequence{}
t.currentUUID = newUUID
t.nextSequence = 0
}

t.sequence = t.nextSequence
t.packets[t.currentUUID].NewInflightPacket(t.sequence)
// if t.timeout > 0 {
// t.packets[t.currentUUID][t.sequence].timeoutTimer = time.Timer(t.timeout)
// }
t.nextSequence++
return t.sequence
}

// DeletePacket removes a packet from the tracker.
func (t *PacketTracker) DeletePacket(u uuid.UUID, seq int) {
SuperQ marked this conversation as resolved.
Show resolved Hide resolved
t.mutex.Lock()
defer t.mutex.Unlock()

if t.hasPacket(u, seq) {
// if _, ok := t.packets[u].GetPacket(seq) ; ok != nil {
// t.packets[u][seq].timeoutTimer.Stop()
// }
t.packets[u].RemovePacket(seq)
}
}

func (t *PacketTracker) hasPacket(u uuid.UUID, seq int) bool {
inflight, ok := t.packets[u]
if ok == false {
return ok
}
_, ok = inflight.GetPacket(seq)
return ok
}

// HasPacket checks the tracker to see if it's currently tracking a packet.
func (t *PacketTracker) HasPacket(u uuid.UUID, seq int) bool {
t.mutex.RLock()
defer t.mutex.Unlock()

return t.hasPacket(u, seq)
}

func (t *PacketTracker) HasUUID(u uuid.UUID) bool {
_, hasUUID := t.packets[u]
return hasUUID
}

func (t *PacketTracker) CurrentUUID() uuid.UUID {
// t.mutex.RLock()
// defer t.mutex.Unlock()

return t.currentUUID
}
64 changes: 24 additions & 40 deletions ping.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,27 +86,22 @@ var (
// New returns a new Pinger struct pointer.
func New(addr string) *Pinger {
r := rand.New(rand.NewSource(getSeed()))
firstUUID := uuid.New()
var firstSequence = map[uuid.UUID]map[int]struct{}{}
firstSequence[firstUUID] = make(map[int]struct{})
return &Pinger{
Count: -1,
Interval: time.Second,
RecordRtts: true,
Size: timeSliceLength + trackerLength,
Timeout: time.Duration(math.MaxInt64),

addr: addr,
done: make(chan interface{}),
id: r.Intn(math.MaxUint16),
trackerUUIDs: []uuid.UUID{firstUUID},
ipaddr: nil,
ipv4: false,
network: "ip",
protocol: "udp",
awaitingSequences: firstSequence,
TTL: 64,
logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())},
addr: addr,
done: make(chan interface{}),
id: r.Intn(math.MaxUint16),
ipaddr: nil,
ipv4: false,
network: "ip",
protocol: "udp",
TTL: 64,
logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())},
}
}

Expand Down Expand Up @@ -142,6 +137,9 @@ type Pinger struct {
// Number of duplicate packets received
PacketsRecvDuplicates int

// Per-packet timeout
PacketTimeout time.Duration

// Round trip time statistics
minRtt time.Duration
maxRtt time.Duration
Expand Down Expand Up @@ -188,14 +186,11 @@ type Pinger struct {
ipaddr *net.IPAddr
addr string

// trackerUUIDs is the list of UUIDs being used for sending packets.
trackerUUIDs []uuid.UUID

ipv4 bool
id int
sequence int
// awaitingSequences are in-flight sequence numbers we keep track of to help remove duplicate receipts
awaitingSequences map[uuid.UUID]map[int]struct{}
// tracker is a PacketTrackrer of UUIDs and sequence numbers.
tracker *PacketTracker
// network is one of "ip", "ip4", or "ip6".
network string
// protocol is "icmp" or "udp".
Expand Down Expand Up @@ -412,6 +407,9 @@ func (p *Pinger) Run() error {
if err != nil {
return err
}

p.tracker = newPacketTracker(p.PacketTimeout)

if conn, err = p.listen(); err != nil {
return err
}
Expand Down Expand Up @@ -614,19 +612,12 @@ func (p *Pinger) getPacketUUID(pkt []byte) (*uuid.UUID, error) {
return nil, fmt.Errorf("error decoding tracking UUID: %w", err)
}

for _, item := range p.trackerUUIDs {
if item == packetUUID {
return &packetUUID, nil
}
if p.tracker.HasUUID(packetUUID) {
return &packetUUID, nil
}
return nil, nil
}

// getCurrentTrackerUUID grabs the latest tracker UUID.
func (p *Pinger) getCurrentTrackerUUID() uuid.UUID {
return p.trackerUUIDs[len(p.trackerUUIDs)-1]
}

func (p *Pinger) processPacket(recv *packet) error {
receivedAt := time.Now()
var proto int
Expand Down Expand Up @@ -675,15 +666,15 @@ func (p *Pinger) processPacket(recv *packet) error {
inPkt.Rtt = receivedAt.Sub(timestamp)
inPkt.Seq = pkt.Seq
// If we've already received this sequence, ignore it.
if _, inflight := p.awaitingSequences[*pktUUID][pkt.Seq]; !inflight {
if !p.tracker.HasPacket(*pktUUID, pkt.Seq) {
p.PacketsRecvDuplicates++
if p.OnDuplicateRecv != nil {
p.OnDuplicateRecv(inPkt)
}
return nil
}
// remove it from the list of sequences we're waiting for so we don't get duplicates.
delete(p.awaitingSequences[*pktUUID], pkt.Seq)
// Remove it from the list of sequences we're waiting for so we don't get duplicates.
p.tracker.DeletePacket(*pktUUID, pkt.Seq)
p.updateStatistics(inPkt)
default:
// Very bad, not sure how this can happen
Expand All @@ -704,7 +695,7 @@ func (p *Pinger) sendICMP(conn packetConn) error {
dst = &net.UDPAddr{IP: p.ipaddr.IP, Zone: p.ipaddr.Zone}
}

currentUUID := p.getCurrentTrackerUUID()
currentUUID := p.tracker.CurrentUUID()
uuidEncoded, err := currentUUID.MarshalBinary()
if err != nil {
return fmt.Errorf("unable to marshal UUID binary: %w", err)
Expand Down Expand Up @@ -752,15 +743,8 @@ func (p *Pinger) sendICMP(conn packetConn) error {
handler(outPkt)
}
// mark this sequence as in-flight
p.awaitingSequences[currentUUID][p.sequence] = struct{}{}
p.sequence = p.tracker.AddPacket()
p.PacketsSent++
p.sequence++
if p.sequence > 65535 {
newUUID := uuid.New()
p.trackerUUIDs = append(p.trackerUUIDs, newUUID)
p.awaitingSequences[newUUID] = make(map[int]struct{})
p.sequence = 0
}
break
}

Expand Down
19 changes: 10 additions & 9 deletions ping_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func TestProcessPacket(t *testing.T) {
shouldBe1++
}

currentUUID := pinger.getCurrentTrackerUUID()
currentUUID := pinger.tracker.CurrentUUID()
uuidEncoded, err := currentUUID.MarshalBinary()
if err != nil {
t.Fatalf("unable to marshal UUID binary: %s", err)
Expand All @@ -37,7 +37,7 @@ func TestProcessPacket(t *testing.T) {
Seq: pinger.sequence,
Data: data,
}
pinger.awaitingSequences[currentUUID][pinger.sequence] = struct{}{}
pinger.tracker.AddPacket()

msg := &icmp.Message{
Type: ipv4.ICMPTypeEchoReply,
Expand Down Expand Up @@ -66,7 +66,7 @@ func TestProcessPacket_IgnoreNonEchoReplies(t *testing.T) {
shouldBe0++
}

currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary()
currentUUID, err := pinger.tracker.CurrentUUID().MarshalBinary()
if err != nil {
t.Fatalf("unable to marshal UUID binary: %s", err)
}
Expand Down Expand Up @@ -109,7 +109,7 @@ func TestProcessPacket_IDMismatch(t *testing.T) {
shouldBe0++
}

currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary()
currentUUID, err := pinger.tracker.CurrentUUID().MarshalBinary()
if err != nil {
t.Fatalf("unable to marshal UUID binary: %s", err)
}
Expand Down Expand Up @@ -189,7 +189,7 @@ func TestProcessPacket_LargePacket(t *testing.T) {
pinger := makeTestPinger()
pinger.Size = 4096

currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary()
currentUUID, err := pinger.tracker.CurrentUUID().MarshalBinary()
if err != nil {
t.Fatalf("unable to marshal UUID binary: %s", err)
}
Expand Down Expand Up @@ -484,6 +484,7 @@ func makeTestPinger() *Pinger {
pinger.protocol = "icmp"
pinger.id = 123
pinger.Size = 0
pinger.tracker = newPacketTracker(time.Second * 5)

return pinger
}
Expand Down Expand Up @@ -542,7 +543,7 @@ func BenchmarkProcessPacket(b *testing.B) {
pinger.protocol = "ip4:icmp"
pinger.id = 123

currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary()
currentUUID, err := pinger.tracker.CurrentUUID().MarshalBinary()
if err != nil {
b.Fatalf("unable to marshal UUID binary: %s", err)
}
Expand Down Expand Up @@ -591,7 +592,7 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) {
dups++
}

currentUUID := pinger.getCurrentTrackerUUID()
currentUUID := pinger.tracker.CurrentUUID()
uuidEncoded, err := currentUUID.MarshalBinary()
if err != nil {
t.Fatalf("unable to marshal UUID binary: %s", err)
Expand All @@ -606,8 +607,8 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) {
Seq: 0,
Data: data,
}
// register the sequence as sent
pinger.awaitingSequences[currentUUID][0] = struct{}{}
// Register the sequence as sent.
pinger.tracker.AddPacket()

msg := &icmp.Message{
Type: ipv4.ICMPTypeEchoReply,
Expand Down