Skip to content

Commit

Permalink
[FIX] Memory leaks, optimize int64 with RDX register divison
Browse files Browse the repository at this point in the history
  • Loading branch information
PiotrKrzem committed Oct 20, 2024
1 parent c89246b commit a33d79c
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 6 deletions.
60 changes: 55 additions & 5 deletions src/plugins/intel_cpu/src/nodes/kernels/x64/random_uniform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ MersenneTwisterGenerator<isa>::MersenneTwisterGenerator(const MersenneTwisterGen
template <x64::cpu_isa_t isa>
void MersenneTwisterGenerator<isa>::generate() {
this->preamble();
registersPool = RegistersPool::create(isa, {rax, rcx, rsp, rdi, k0});
registersPool = RegistersPool::create(isa, {rax, rcx, rsp, rdi, rdx, k0});

r64_dst = getReg64();
r64_state = getReg64();
Expand Down Expand Up @@ -747,6 +747,13 @@ template <x64::cpu_isa_t isa>
void MersenneTwisterGenerator<isa>::convertToOutputTypeMersenne(const Vmm& v_result, const Vmm& v_min, const Vmm& v_range, const Vmm& v_dst, const Xbyak::Reg64& r64_elements_remaining) {
using namespace Xbyak;

const auto r64_aux = getReg64();
const auto r64_aux_2 = getReg64();

const auto r32_aux = Xbyak::Reg32(r64_aux.getIdx());
const auto r32_aux_2 = Xbyak::Reg32(r64_aux_2.getIdx());


if (m_jcp.out_data_type == element::f32) {
// Apply mask and divisor
pand(v_result, v_mask);
Expand Down Expand Up @@ -790,11 +797,54 @@ void MersenneTwisterGenerator<isa>::convertToOutputTypeMersenne(const Vmm& v_res
movdqu(ptr[r64_dst], v_result);
} else if (m_jcp.out_data_type == element::i64) {
if (m_jcp.optimized) {
// Convert to int64 and store result
movdqu(ptr[r64_dst], v_result);
// Move the lower 32 bits of v_result to r32_aux
movd(r32_aux, v_result);

// Move the lower 32 bits of v_range to r32_aux_2
movd(r32_aux_2, v_range);

// Perform the modulo operation
xor_(rdx, rdx); // Clear RDX (set it to zero)
mov(rax, r32_aux); // Move r32_aux to RAX for division
div(r32_aux_2); // Divide RAX by r32_aux_2, quotient in RAX, remainder in RDX

// Move the remainder (result % range) to r32_aux
mov(r32_aux, rdx);

// Add v_min to r32_aux
movd(r32_aux_2, v_min);
add(r32_aux, r32_aux_2);

// Store the result in the destination pointer
mov(ptr[r64_dst], r32_aux);
} else {
// Convert to int64 with optimization and store result
movdqu(ptr[r64_dst], v_result);
// Extract the first two 32-bit values from v_result
movd(r32_aux, v_result); // Move the lower 32 bits of v_result[0] to r32_aux
pextrd(r32_aux_2, v_result, 1); // Extract the second 32-bit value (v_result[1]) to r32_aux_2

// Combine the two 32-bit values into a 64-bit integer
mov(r64_aux, r32_aux); // Move r32_aux to the lower 32 bits of r64_aux
shl(r64_aux, 32); // Shift r64_aux left by 32 bits
mov(r64_aux_2, r32_aux_2); // Move r32_aux_2 to r64_aux_2
or_(r64_aux, r64_aux_2); // Combine with r64_aux_2 to form a 64-bit integer

// Prepare for division
xor_(rdx, rdx); // Clear RDX (set it to zero)
mov(rax, r64_aux); // Move the combined 64-bit value to RAX

// Perform the division
mov(r64_aux_2, qword[v_range]); // Move the range value to r64_aux_2
div(r64_aux_2); // Divide RAX by r64_aux_2, quotient in RAX, remainder in RDX

// Move the remainder to r64_aux
mov(r64_aux, rdx); // Move the remainder (result % range) to r64_aux

// Add the minimum value
mov(r64_aux_2, qword[v_min]); // Move the minimum value to r64_aux_2
add(r64_aux, r64_aux_2); // Add r64_aux_2 to r64_aux

// Store the result in the destination pointer
mov(qword[r64_dst], r64_aux); // Move the final result to the memory location pointed by r64_dst
}
} else {
OPENVINO_THROW("RandomUniform kernel does not support precision ", m_jcp.out_data_type, " for ", x64::get_isa_info());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ struct MersenneTwisterGeneratorCallArgs {
uint64_t elements_remaining = 0lu;
bool optimization_enabled = false;
uint32_t out_data_type = 0u;

};

template <dnnl::impl::cpu::x64::cpu_isa_t isa>
Expand Down

0 comments on commit a33d79c

Please sign in to comment.