Skip to content

Commit

Permalink
Merge pull request #47 from Photoroom/ben/random_sampling
Browse files Browse the repository at this point in the history
[DB API] Add random sampling support
  • Loading branch information
blefaudeux authored Nov 25, 2024
2 parents be6c978 + d21930b commit a971595
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 2 deletions.
7 changes: 6 additions & 1 deletion pkg/serdes.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,12 @@ func fetchSample(config *SourceDBConfig, http_client *http.Client, sample_result
}

func getHTTPRequest(api_url string, api_key string, request dbRequest) *http.Request {
request_url, _ := http.NewRequest("GET", api_url+"images/", nil)
if request.randomSampling {
api_url += "images/random/"
} else {
api_url += "images/"
}
request_url, _ := http.NewRequest("GET", api_url, nil)
request_url.Header.Add("Authorization", "Token "+api_key)
req := request_url.URL.Query()

Expand Down
50 changes: 49 additions & 1 deletion tests/client_db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ func get_default_test_config() datago.DatagoConfig {
db_config := datago.GetSourceDBConfig()
db_config.Sources = get_test_source()
db_config.PageSize = 32

config.SourceConfig = db_config
return config
}
Expand Down Expand Up @@ -379,3 +378,52 @@ func TestMultipleSources(t *testing.T) {
}
client.Stop()
}

// TestRandomSampling draws the same number of samples from two independently
// created clients that both have RandomSampling enabled, and fails if the two
// runs yield exactly the same set of sample IDs.
func TestRandomSampling(t *testing.T) {
	clientConfig := get_default_test_config()
	clientConfig.SamplesBufferSize = 1

	// The source config is stored behind an interface: check the assertion so a
	// mismatch is reported as a test failure rather than a bare panic.
	dbConfig, ok := clientConfig.SourceConfig.(datago.SourceDBConfig)
	if !ok {
		t.Fatalf("unexpected source config type %T", clientConfig.SourceConfig)
	}
	dbConfig.RandomSampling = true
	clientConfig.SourceConfig = dbConfig

	const samplesPerRun = 10

	// collectIDs spins up a fresh client, pulls samplesPerRun samples and
	// returns the set of IDs seen. The client is always stopped on the way out.
	collectIDs := func() map[string]struct{} {
		client := datago.GetClient(clientConfig)
		defer client.Stop()

		ids := make(map[string]struct{}, samplesPerRun)
		for i := 0; i < samplesPerRun; i++ {
			sample := client.GetSample()
			ids[sample.ID] = struct{}{}
		}
		return ids
	}

	firstRun := collectIDs()
	secondRun := collectIDs()

	// Two sets are equal when they have the same size and every key of the
	// first one is present in the second one.
	setsAreEqual := func(set1, set2 map[string]struct{}) bool {
		if len(set1) != len(set2) {
			return false
		}
		for k := range set1 {
			if _, exists := set2[k]; !exists {
				return false
			}
		}
		return true
	}

	if setsAreEqual(firstRun, secondRun) {
		t.Error("Random sampling is not working")
	}
}

0 comments on commit a971595

Please sign in to comment.