Skip to content

Commit

Permalink
fix argument order for Matrix#element (#432)
Browse files Browse the repository at this point in the history
  • Loading branch information
yasuhito authored Jun 5, 2024
1 parent 1225c02 commit 8dd5d8a
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 40 deletions.
4 changes: 2 additions & 2 deletions apps/tutorial/serviceworker.js

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions apps/tutorial/serviceworker.js.map

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion apps/tutorial/src/serviceworker.js
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ const pickTargetAmplitudes = (targets, amplitudes) => {
if (each >= amplitudes.height) {
map[each] = [0, 0]
} else {
const c = amplitudes.element(0, each).value
const c = amplitudes.element(each, 0).value
map[each] = [c.real, c.imag]
}
return map
Expand Down
2 changes: 1 addition & 1 deletion apps/www/app/assets/javascripts/serviceworker.js
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ const pickTargetAmplitudes = (targets, amplitudes) => {
if (each >= amplitudes.height) {
map[each] = [0, 0]
} else {
const c = amplitudes.element(0, each).value
const c = amplitudes.element(each, 0).value
map[each] = [c.real, c.imag]
}
return map
Expand Down
60 changes: 29 additions & 31 deletions packages/simulator/src/matrix.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ export class Matrix {
}
}

return Matrix.create(width, height, buf)
return Matrix.create(height, width, buf)
}

private static create(width: number, height: number, buffer: Float64Array): Result<Matrix, Error> {
private static create(height: number, width: number, buffer: Float64Array): Result<Matrix, Error> {
if (width < 0) {
return err(Error(`width(${width}) < 0`))
}
Expand All @@ -86,17 +86,31 @@ export class Matrix {
return err(Error(`width(${width})*height(${height})*2 !== buffer.length(${buffer.length})`))
}

return ok(new Matrix(width, height, buffer))
return ok(new Matrix(height, width, buffer))
}

private constructor(width: number, height: number, buffer: Float64Array) {
this.width = width
private constructor(height: number, width: number, buffer: Float64Array) {
this.height = height
this.width = width
this.buffer = buffer

this.plus = this.add // alias for add
}

/**
* Retrieves the value at the specified row and column in the matrix.
*/
element(row: number, col: number): Result<Complex, Error> {
if (row < 0 || col < 0 || row >= this.height || col >= this.width) {
return err(Error('Element out of range'))
}

const ri = (this.width * row + col) * 2 // real part index
const ii = ri + 1 // imaginary part index

return ok(new Complex(this.buffer[ri], this.buffer[ii]))
}

timesQubitOperation(operation2x2: Matrix, qubitIndex: number, controlMask: number): Matrix {
Util.need((controlMask & (1 << qubitIndex)) === 0, 'Matrix.timesQubitOperation: self-controlled')
Util.need(operation2x2.width === 2 && operation2x2.height === 2, 'Matrix.timesQubitOperation: not 2x2')
Expand Down Expand Up @@ -128,23 +142,7 @@ export class Matrix {
}
}

return Matrix.create(w, h, buf)._unsafeUnwrap()
}

/**
* Returns element (col,row) of the matrix.
*
* @param col - The column index
* @param row - The row index
* @returns A result object with the element or an error
*/
element(col: number, row: number): Result<Complex, Error> {
if (col < 0 || row < 0 || col >= this.width || row >= this.height) {
return err(Error('Element out of range'))
}

const i = (this.width * row + col) * 2
return ok(new Complex(this.buffer[i], this.buffer[i + 1]))
return Matrix.create(h, w, buf)._unsafeUnwrap()
}

/**
Expand Down Expand Up @@ -183,7 +181,7 @@ export class Matrix {

const col = []
for (let row = 0; row < this.height; row++) {
col.push(this.element(colIndex, row)._unsafeUnwrap())
col.push(this.element(row, colIndex)._unsafeUnwrap())
}
return ok(col)
}
Expand All @@ -195,7 +193,7 @@ export class Matrix {
*/
rows(): Complex[][] {
return range(0, this.height - 1).map<Complex[]>(row =>
range(0, this.width - 1).map<Complex>(col => this.element(col, row)._unsafeUnwrap()),
range(0, this.width - 1).map<Complex>(col => this.element(row, col)._unsafeUnwrap()),
)
}

Expand Down Expand Up @@ -265,7 +263,7 @@ export class Matrix {
}
}

return new Matrix(w, h, newBuf)
return new Matrix(h, w, newBuf)
}

/**
Expand All @@ -286,7 +284,7 @@ export class Matrix {
newBuffer[i] = b1[i] + b2[i]
}

return ok(new Matrix(w, h, newBuffer))
return ok(new Matrix(h, w, newBuffer))
}

plus = this.add.bind(this)
Expand All @@ -309,7 +307,7 @@ export class Matrix {
newBuffer[i] = b1[i] - b2[i]
}

return ok(new Matrix(w, h, newBuffer))
return ok(new Matrix(h, w, newBuffer))
}

/**
Expand Down Expand Up @@ -356,7 +354,7 @@ export class Matrix {
}
}

return new Matrix(w, h, newBuffer)
return new Matrix(h, w, newBuffer)
}

/**
Expand Down Expand Up @@ -430,7 +428,7 @@ export class Matrix {
}

clone(): Matrix {
return new Matrix(this.width, this.height, this.buffer.slice())
return new Matrix(this.height, this.width, this.buffer.slice())
}

private norm2(): number {
Expand Down Expand Up @@ -469,7 +467,7 @@ export class Matrix {
}
}

return ok(new Matrix(w, h, newBuffer))
return ok(new Matrix(h, w, newBuffer))
}

private multScalar(v: number | Complex): Matrix {
Expand All @@ -484,6 +482,6 @@ export class Matrix {
newBuffer[i + 1] = vr * si + vi * sr
}

return new Matrix(this.width, this.height, newBuffer)
return new Matrix(this.height, this.width, newBuffer)
}
}
6 changes: 3 additions & 3 deletions packages/simulator/src/state-vector.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ export class StateVector {
}

amplifier(index: number): Complex {
return this.matrix.element(0, index)._unsafeUnwrap()
return this.matrix.element(index, 0)._unsafeUnwrap()
}

setAmplifier(index: number, value: Complex): Result<StateVector, Error> {
Expand Down Expand Up @@ -146,9 +146,9 @@ export class StateVector {
if (!survived) continue

const amp = this.matrix
.element(0, ket)
.element(ket, 0)
._unsafeUnwrap()
.times(this.matrix.element(0, bra)._unsafeUnwrap().conjugate())
.times(this.matrix.element(bra, 0)._unsafeUnwrap().conjugate())
if (amp.isEqualTo(0)) continue

const ketMat =
Expand Down

0 comments on commit 8dd5d8a

Please sign in to comment.