From 2fa039f9bb323c423020a1f72953342ae38872f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vladimir=20Milovanovi=C4=87?= Date: Sun, 30 Jul 2023 22:38:43 +0200 Subject: [PATCH] Bump to chisel3.6 (#247) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Bump breeze and spire versions * Bump spire to 0.18.0 * Bump breeze to 2.1.0 * Override method 'reverse' in SignAlgebra due to both Order[A] and CMonoid[A] now having 'reverse' methods each. * Remove parameterless non-Unit method definitions and invocations * Bump Chisel version to 3.6 * Bump dsptools version to 1.6 * Bump chiseltest version to 0.6 * Bump scalatest version to 3.2.15 * Bump fixedpoint to version supporting Chisel 3.6.0 * Apply fixedpoint.shadow.Mux to more places * Add '-Xfatal-warnings' compiler option * Remove Saturate as it uses firrtl transformations * Add branch 1.6.x to CI workflow * Update scalafmt settings * Add sclaafmt sbt plugin * Reformat source files according to scalafmt rules * Add scalafmt step to CI * Changed license from BSD to Apache-2.0 in README.md --------- Co-authored-by: Aleksandar Kondić --- .github/workflows/test.yml | 4 + .scalafmt.conf | 8 +- README.md | 74 +++-- build.sbt | 15 +- fixedpoint | 2 +- project/plugins.sbt | 1 + .../counters/CounterWithReset.scala | 2 +- .../counters/ShiftRegisterWithReset.scala | 6 +- .../dsptools/dspmath/ExtendedEuclid.scala | 3 +- .../dsptools/dspmath/Factorization.scala | 24 +- src/main/scala/dsptools/misc/BitWidth.scala | 18 +- .../scala/dsptools/misc/DspException.scala | 3 +- .../dsptools/misc/DspTesterUtilities.scala | 6 +- .../scala/dsptools/numbers/DspContext.scala | 18 +- .../dsptools/numbers/algebra_types/Eq.scala | 10 +- .../numbers/algebra_types/Order.scala | 5 +- .../numbers/algebra_types/PartialOrder.scala | 14 +- .../dsptools/numbers/algebra_types/Ring.scala | 6 +- .../numbers/algebra_types/Signed.scala | 16 +- .../numbers/algebra_types/helpers/Sign.scala | 45 +-- .../binary_types/BinaryRepresentation.scala | 8 +- .../numbers/binary_types/NumberBits.scala | 8 +- .../DspRealVerilatorBB.scala | 2 +- .../blackbox_compatibility/TrigUtility.scala | 37 +-- .../numbers/chisel_concrete/DspComplex.scala | 33 +- .../numbers/chisel_concrete/DspReal.scala | 139 +++++---- .../numbers/chisel_concrete/RealTrig.scala | 28 +- .../chisel_types/DspComplexTypeClass.scala | 57 ++-- .../chisel_types/DspRealTypeClass.scala | 70 +++-- .../chisel_types/FixedPointTypeClass.scala | 198 ++++++------ .../numbers/chisel_types/SIntTypeClass.scala | 68 +++-- .../numbers/chisel_types/UIntTypeClass.scala | 76 ++--- .../convertible_types/ConvertableTo.scala | 4 +- .../dsptools/numbers/implicits/AllOps.scala | 106 +++---- .../numbers/implicits/ImplicitSyntax.scala | 24 +- .../numbers/implicits/ImplicitsTop.scala | 16 +- .../numbers/number_types/Numbers.scala | 20 +- src/main/scala/dsptools/numbers/package.scala | 27 +- .../numbers/representations/BaseN.scala | 3 +- .../dsptools/numbers/rounding/Saturate.scala | 282 ------------------ .../examples/StreamingAutocorrelator.scala | 14 +- .../examples/TransposedStreamingFIR.scala | 29 +- src/main/scala/examples/gainOffCorr.scala | 20 +- 43 files changed, 680 insertions(+), 869 deletions(-) delete mode 100644 src/main/scala/dsptools/numbers/rounding/Saturate.scala diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 575cbfbc..2b306bf8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -5,6 +5,7 @@ on: push: branches: - master + - 1.6.x - 1.5.x - 1.4.x - 1.3.x @@ -25,6 +26,9 @@ jobs: uses: coursier/setup-action@v1 - name: Cache uses: coursier/cache-action@v6 + - name: Formatting check + id: scalafmt + run: sbt scalafmtCheckAll - name: Documentation id: doc run: sbt doc diff --git a/.scalafmt.conf b/.scalafmt.conf index c61875e6..f74e5504 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -1,3 +1,5 @@ +version = 2.6.4 + maxColumn = 120 align = most continuationIndent.defnSite = 2 @@ -13,7 +15,11 @@ align.tokens.add = [ } ] -newlines.alwaysBeforeCurlyBraceLambdaParams
 = false +newlines.alwaysBeforeCurlyBraceLambdaParams = false +newlines.alwaysBeforeMultilineDef = false +newlines.implicitParamListModifierForce = [before] + +verticalMultiline.atDefnSite = true optIn.annotationNewlines = true diff --git a/README.md b/README.md index 45e655c5..31931f4e 100644 --- a/README.md +++ b/README.md @@ -22,12 +22,11 @@ Some of the goals of dsptools are to enable: 2. Enhanced support for designing and testing DSP with generic types (i.e. switching between **DSPReal** for verifying functional correctness with double-precision floating point and **FixedPoint** for evaluating fixed-point design metrics by changing a single parameter). 3. More useful and universal testing platform for numeric types! - > Numbers are displayed in their correct formats instead of hex for peek, poke, and expect operations. Additionally, if your tester extends **DSPTester**, you can optionally dump your test sequence to a **Verilog testbench** that replays the test for functional verification on all simulation platforms (i.e. Xilinx, Altera, etc. instead of only VCS). The tolerance of comparisons with expected values can also be changed via `DSPTester.setTol(floTol = decimal_tolerance, - fixedTol = number_of_bits)`. + > Numbers are displayed in their correct formats instead of hex for peek, poke, and expect operations. Additionally, if your tester extends **DSPTester**, you can optionally dump your test sequence to a **Verilog testbench** that replays the test for functional verification on all simulation platforms (i.e. Xilinx, Altera, etc. instead of only VCS). The tolerance of comparisons with expected values can also be changed via `DSPTester.setTol(floTol = decimal_tolerance, fixedTol = number_of_bits)`. 4. **Miscellaneous additional features** - - Wide range of LUT modules for ease of generating lookup tables from pre-calculated constants (no intermediate representation) - - Memory modules that abstract out confusion associated with Chisel Mem + - Wide range of LUT modules for ease of generating lookup tables from pre-calculated constants (no intermediate representation). + - Memory modules that abstract out confusion associated with Chisel `Mem`. - Generates useful helper files with each Verilog output (constraints, generator parameters used, etc.). - Easier to rename modules & signals and have renaming actually succeed. - Expanding Support for non-base-2 math. @@ -50,46 +49,42 @@ See Github for the latest release. Snapshots are also published on Sonatype, which are beneficial if you want to use the latest features. Projects that dsptools depends on are: - -- [FIRRTL](https://github.com/ucb-bar/firrtl) - -- [FIRRTL Interpreter](https://github.com/ucb-bar/firrtl-interpreter) - -- [Chisel3](https://github.com/ucb-bar/chisel3) - -- [Chisel Testers](https://github.com/ucb-bar/chisel-testers) +* [FIRRTL](https://github.com/ucb-bar/firrtl) +* [FIRRTL Interpreter](https://github.com/ucb-bar/firrtl-interpreter) +* [Chisel3](https://github.com/ucb-bar/chisel3) +* [Chisel Testers](https://github.com/ucb-bar/chisel-testers) ---------- Numeric Typeclasses -=============== +=================== This library defines a number of typeclasses for numeric types. -A brief explanation of how typeclasses work in scala can be found [here](http://typelevel.org/cats/typeclasses.html) and [here](http://blog.jaceklaskowski.pl/2015/05/15/ad-hoc-polymorphism-in-scala-with-type-classes.html). +A brief explanation of how typeclasses work in scala can be found [here](http://typelevel.org/cats/typeclasses.html). Our DSP-specific typeclasses are built on top of [spire](https://github.com/non/spire). -The goal of these typeclasses is to make it easy to write chisel modules that treat the number representation as a parameter. -For example, using typeclasses you can write chisel that generates an FIR filter for both real and complex numbers. -You can also use typeclasses to write chisel that generates a circuit implementation using floating point (via Verilog's real type). +The goal of these typeclasses is to make it easy to write Chisel modules that treat the number representation as a parameter. +For example, using typeclasses you can write Chisel that generates an FIR filter for both real and complex numbers. +You can also use typeclasses to write Chisel that generates a circuit implementation using floating point (via Verilog's real type). After testing that your circuit implementation works with floating point, you can use the same code to generate a fixed point version of the circuit suitable for synthesis. **For a additional, more detailed description of the Numeric classes in dsptools: see [The Numbers ReadMe](https://github.com/ucb-bar/dsptools/blob/master/src/main/scala/dsptools/numbers/README.md)** -A generic function in scala is defined like so: +A generic function in Scala programming language is defined like so: ```def func[T](in: T): T``` -This means that you can call `func(obj)` for an object of any type. If `obj` is of type `Q`, you can write `func[Q](obj)` to specify that we want the `Q` version of the generic function `func`, but this is only necessary if the scala compiler can't figure out what `Q` is supposed to be. +This means that you can call `func(obj)` for an object of any type. If `obj` is of type `Q`, you can write `func[Q](obj)` to specify that we want the `Q` version of the generic function `func`, but this is only necessary if the Scala compiler can't figure out what `Q` is supposed to be. You can also write ```class SomeClass[T]``` and use `T` like it is a real type for any member functions of variables. -To write a generic chisel Module, we might try to write +To write a generic Chisel Module, we might try to write -``` +```scala class Passthrough[T](gen: T) extends Module { val io = new IO(Bundle { val in = Input(gen) @@ -102,18 +97,18 @@ class Passthrough[T](gen: T) extends Module { Here, `gen` is a parameter specifying the type you want to use for your IO's, so you could write `Module(new Passthrough(SInt(width=10)))` or `Module(new Passthrough(new Bundle { ... }))`. Unfortunately, there's a problem with this. `T` can be any type, and a lot of types don't make sense, like `String` or `()=>Unit`. -This will not compile, because `Input()`, `Output()`, and `:=` are functions defined on chisel types. +This will not compile, because `Input()`, `Output()`, and `:=` are functions defined on Chisel types. We can fix this problem by writing ```class Passthrough[T<:Data](gen: T) extends Module``` -This type constraint means that we have to choose `T` to be a subtype of the chisel type `Data`. +This type constraint means that we have to choose `T` to be a subtype of the Chisel type `Data`. Things like `UInt`, `SInt`, and `Bundle` are subtypes of `Data`. Now the example above should compile. This example isn't very interesting, though. `Data` lets you do basic things like assignment and make registers, but doesn't define any mathematical operations, so if we write -``` +```scala class Doubler[T<:Data](gen: T) extends Module { val io = IO(new Bundle { val in = Input(gen) @@ -127,7 +122,7 @@ it won't compile. This is where typeclasses come in. This library defines a trait -``` +```scala trait Real[T] { ... def plus(x: T, y: T): T @@ -137,14 +132,14 @@ trait Real[T] { as well as an implicit conversion so that `a+b` gets converted to `Real[T].plus(a,b)`. `Real[T]` is a typeclass. -Typeclasses are a useful pattern in scala, so there is syntactic sugar to make using them easy: +Typeclasses are a useful pattern in Scala, so there is syntactic sugar to make using them easy: -``` +```scala import dsptools.numbers._ class Doubler[T<:Data:Real](gen: T) extends Module ``` -Note: If you don't include the `:Real` at the end, the scala compiler will think `io.in + io.in` is string concatenation and you'll get a weird error saying +*Note*: If you don't include the `:Real` at the end, the Scala compiler will think `io.in + io.in` is string concatenation and you'll get a weird error saying ``` [error] found : T @@ -152,34 +147,35 @@ Note: If you don't include the `:Real` at the end, the scala compiler will think ``` Some useful typeclasses: -- Ring +* Ring - defines +, *, -, **, zero, one - defined in [Spire](https://github.com/non/spire) - Read: https://en.wikipedia.org/wiki/Ring_(mathematics) - Note: We chose to restrict ourselves to `Ring` rather than `Field` because division is particularly expensive and nuanced in hardware. Rather than typing `a / b` we think it is better to require users to instantiate a module and think about what's going on. -- Eq - - defines === and =/= (returning chisel Bools!) -- PartialOrder +* Eq + - defines === and =/= (returning Chisel Bools!) +* PartialOrder - extends Eq - defines >, <, <=, >= (returning a `ValidIO[ComparisonBundle]` that has `valid` false if the objects are not comparable -- Order +* Order - extends PartialOrder - defines >, <, <=, >=, min, max -- Sign +* Sign - defines abs, isSignZero, isSignPositive, isSignNegative, isSignNonZero, isSignNonPositive, isSignNonNegative -- Real +* Real - extends Ring with Order with Sign - defines ceil, round, floor, isWhole - defines a bunch of conversion methods from ConvertableTo, e.g. fromDouble, fromInt -- Integer +* Integer - extends Real - defines mod ---------- Rocket-chip -=============== +=========== + Integration of dsptools with a rocket-chip based project: The github project [Rocket Dsp Utils](https://github.com/chick/rocket-dsp-utils) contains useful tools @@ -189,6 +185,6 @@ These tools formerly were contained in this repo under the `rocket` sub-director ---------- -This code is maintained by [Chick](https://github.com/chick), [Angie](https://github.com/shunshou) and [Paul](https://github.com/grebe). Let us know if you have any questions/feedback! +This code was maintained by [Chick](https://github.com/chick), [Angie](https://github.com/shunshou) and [Paul](https://github.com/grebe). Let us know if you have any questions/feedback! -Copyright (c) 2015 - 2021 The Regents of the University of California. Released under the Modified (3-clause) BSD license. +Copyright (c) 2015 - 2022 The Regents of the University of California. Released under the Apache-2.0 license. diff --git a/build.sbt b/build.sbt index e51367ec..200e6c56 100644 --- a/build.sbt +++ b/build.sbt @@ -5,15 +5,15 @@ enablePlugins(SiteScaladocPlugin) enablePlugins(GhpagesPlugin) val defaultVersions = Map( - "chisel3" -> "3.5-SNAPSHOT", - "chiseltest" -> "0.5-SNAPSHOT" + "chisel3" -> "3.6-SNAPSHOT", + "chiseltest" -> "0.6-SNAPSHOT" ) name := "dsptools" val commonSettings = Seq( organization := "edu.berkeley.cs", - version := "1.5-SNAPSHOT", + version := "1.6-SNAPSHOT", git.remoteRepo := "git@github.com:ucb-bar/dsptools.git", autoAPIMappings := true, scalaVersion := "2.13.10", @@ -23,6 +23,7 @@ val commonSettings = Seq( "-deprecation", "-feature", "-language:reflectiveCalls", + "-Xfatal-warnings", "-Ymacro-annotations"), javacOptions ++= Seq("-source", "1.8", "-target", "1.8"), pomExtra := (http://chisel.eecs.berkeley.edu/ @@ -78,15 +79,15 @@ val commonSettings = Seq( val dsptoolsSettings = Seq( name := "dsptools", libraryDependencies ++= Seq( - "org.typelevel" %% "spire" % "0.17.0", - "org.scalanlp" %% "breeze" % "1.1", - "org.scalatest" %% "scalatest" % "3.2.+" % "test" + "org.typelevel" %% "spire" % "0.18.0", + "org.scalanlp" %% "breeze" % "2.1.0", + "org.scalatest" %% "scalatest" % "3.2.15" % "test" ), ) val fixedpointSettings = Seq( libraryDependencies ++= Seq( - "org.scalatest" %% "scalatest" % "3.2.+" % "test", + "org.scalatest" %% "scalatest" % "3.2.15" % "test", "org.scalatestplus" %% "scalacheck-1-14" % "3.2.2.0" % "test", ) ) diff --git a/fixedpoint b/fixedpoint index 4e69127e..3d5ae17e 160000 --- a/fixedpoint +++ b/fixedpoint @@ -1 +1 @@ -Subproject commit 4e69127e1897c8e3ca809748cc9491597cdfcc06 +Subproject commit 3d5ae17eb6dd90353b3d825761605df355f73560 diff --git a/project/plugins.sbt b/project/plugins.sbt index 764cb8b8..ef70a41c 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -13,3 +13,4 @@ addSbtPlugin("com.eed3si9n" % "sbt-sriracha" % "0.1.0") addSbtPlugin("com.geirsson" % "sbt-ci-release" % "1.5.4") +addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.4.6") diff --git a/src/main/scala/dsptools/TODO/move_to_utilities/counters/CounterWithReset.scala b/src/main/scala/dsptools/TODO/move_to_utilities/counters/CounterWithReset.scala index dd392bc9..660d18db 100644 --- a/src/main/scala/dsptools/TODO/move_to_utilities/counters/CounterWithReset.scala +++ b/src/main/scala/dsptools/TODO/move_to_utilities/counters/CounterWithReset.scala @@ -7,7 +7,7 @@ import chisel3._ object CounterWithReset { def apply(cond: Bool, n: Int, reset: Bool): (UInt, Bool) = { val c = chisel3.util.Counter(cond, n) - if (n > 1) { when (reset) { c._1 := 0.U } } + if (n > 1) { when(reset) { c._1 := 0.U } } c } } diff --git a/src/main/scala/dsptools/TODO/move_to_utilities/counters/ShiftRegisterWithReset.scala b/src/main/scala/dsptools/TODO/move_to_utilities/counters/ShiftRegisterWithReset.scala index e7337386..8784c45a 100644 --- a/src/main/scala/dsptools/TODO/move_to_utilities/counters/ShiftRegisterWithReset.scala +++ b/src/main/scala/dsptools/TODO/move_to_utilities/counters/ShiftRegisterWithReset.scala @@ -5,8 +5,8 @@ package dsptools.counters import chisel3._ import chisel3.util.RegEnable -object ShiftRegisterWithReset -{ +object ShiftRegisterWithReset { + /** Returns the n-cycle delayed version of the input signal. * * @param in input to delay @@ -16,7 +16,7 @@ object ShiftRegisterWithReset def apply[T <: Data](in: T, n: Int, reset: T, en: Bool = true.B): T = { // The order of tests reflects the expected use cases. if (n != 0) { - RegEnable(apply(in, n-1, reset, en), reset, en) + RegEnable(apply(in, n - 1, reset, en), reset, en) } else { in } diff --git a/src/main/scala/dsptools/dspmath/ExtendedEuclid.scala b/src/main/scala/dsptools/dspmath/ExtendedEuclid.scala index b8394493..87400ef7 100644 --- a/src/main/scala/dsptools/dspmath/ExtendedEuclid.scala +++ b/src/main/scala/dsptools/dspmath/ExtendedEuclid.scala @@ -3,6 +3,7 @@ package dsptools.dspmath object ExtendedEuclid { + /** Extended Euclidean Algorithm * ax + by = gcd(a, b) * Inputs: a, b @@ -16,4 +17,4 @@ object ExtendedEuclid { (gcd, x - (b / a) * y, y) } } -} \ No newline at end of file +} diff --git a/src/main/scala/dsptools/dspmath/Factorization.scala b/src/main/scala/dsptools/dspmath/Factorization.scala index af6f4a2d..c982e1c5 100644 --- a/src/main/scala/dsptools/dspmath/Factorization.scala +++ b/src/main/scala/dsptools/dspmath/Factorization.scala @@ -3,13 +3,16 @@ package dsptools.dspmath case class RadPow(rad: Int, pow: Int) { + /** `r ^ p` */ def get: Int = BigInt(rad).pow(pow).toInt + /** Factorize i.e. rad = 4, pow = 3 -> Seq(4, 4, 4) */ def factorize: Seq[Int] = Seq.fill(pow)(rad) } case class Factorization(supportedRadsUnsorted: Seq[Seq[Int]]) { + /** Supported radices, MSD First */ private val supportedRads = supportedRadsUnsorted.map(_.sorted.reverse) @@ -22,17 +25,19 @@ case class Factorization(supportedRadsUnsorted: Seq[Seq[Int]]) { // Test if n can be factored by each of the supported radices (mod = 0) // Count # of times it can be factored var unfactorized = n - val radPows = for (primeGroup <- supportedRads) yield { for (rad <- primeGroup) yield { - var (mod, pow) = (0, 0) - while (mod == 0) { - mod = unfactorized % rad - if (mod == 0) { - pow = pow + 1 - unfactorized = unfactorized / rad + val radPows = for (primeGroup <- supportedRads) yield { + for (rad <- primeGroup) yield { + var (mod, pow) = (0, 0) + while (mod == 0) { + mod = unfactorized % rad + if (mod == 0) { + pow = pow + 1 + unfactorized = unfactorized / rad + } } + RadPow(rad, pow) } - RadPow(rad, pow) - }} + } // If n hasn't completely been factorized, then an unsupported radix is required require(unfactorized == 1, s"$n is invalid for supportedRads.") radPows @@ -82,4 +87,3 @@ case class Factorization(supportedRadsUnsorted: Seq[Seq[Int]]) { } } - diff --git a/src/main/scala/dsptools/misc/BitWidth.scala b/src/main/scala/dsptools/misc/BitWidth.scala index 956104e3..2b22d879 100644 --- a/src/main/scala/dsptools/misc/BitWidth.scala +++ b/src/main/scala/dsptools/misc/BitWidth.scala @@ -3,6 +3,7 @@ package dsptools.misc object BitWidth { + /** * Utility function that computes bits required for a number * @@ -10,7 +11,7 @@ object BitWidth { * @return */ def computeBits(n: BigInt): Int = { - n.bitLength + (if(n < 0) 1 else 0) + n.bitLength + (if (n < 0) 1 else 0) } /** @@ -23,14 +24,12 @@ object BitWidth { * @return minimum required bits for an SInt */ def requiredBitsForSInt(num: BigInt): Int = { - if(num == BigInt(0) || num == -BigInt(1)) { + if (num == BigInt(0) || num == -BigInt(1)) { 1 - } - else { + } else { if (num < 0) { computeBits(num) - } - else { + } else { computeBits(num) + 1 } } @@ -50,11 +49,10 @@ object BitWidth { * @return minimum required bits for an SInt */ def requiredBitsForUInt(num: BigInt): Int = { - if(num == BigInt(0)) { + if (num == BigInt(0)) { 1 - } - else { + } else { computeBits(num) } } -} \ No newline at end of file +} diff --git a/src/main/scala/dsptools/misc/DspException.scala b/src/main/scala/dsptools/misc/DspException.scala index 1b4f662a..cbcf5831 100644 --- a/src/main/scala/dsptools/misc/DspException.scala +++ b/src/main/scala/dsptools/misc/DspException.scala @@ -4,5 +4,4 @@ package dsptools import chisel3.ChiselException -case class DspException(message: String) extends ChiselException(message) { -} +case class DspException(message: String) extends ChiselException(message) {} diff --git a/src/main/scala/dsptools/misc/DspTesterUtilities.scala b/src/main/scala/dsptools/misc/DspTesterUtilities.scala index e8f330ba..a729ab14 100644 --- a/src/main/scala/dsptools/misc/DspTesterUtilities.scala +++ b/src/main/scala/dsptools/misc/DspTesterUtilities.scala @@ -6,7 +6,7 @@ import chisel3.{fromDoubleToLiteral => _, fromIntToBinaryPoint => _, _} import fixedpoint._ import dsptools.DspException import dsptools.numbers.{DspComplex, DspReal} -import chisel3.internal.InstanceId +import chisel3.InstanceId //scalastyle:off cyclomatic.complexity method.length object DspTesterUtilities { @@ -99,7 +99,7 @@ object DspTesterUtilities { case _ => s"${width}-bit F" } case r: DspReal => "R" - case u: UInt => s"${width}-bit U" + case u: UInt => s"${width}-bit U" case s: SInt => s"${width}-bit S" case c: DspComplex[_] => { val realInfo = bitInfo(c.real.asInstanceOf[Data]) @@ -115,7 +115,7 @@ object DspTesterUtilities { // Round value if data type is integer def roundData(data: Data, value: Double): Double = { data match { - case _: SInt | _: UInt => value.round.toDouble + case _: SInt | _: UInt => value.round.toDouble case _: DspReal | _: FixedPoint => value case _ => throw DspException("Invalid data type for rounding determination") } diff --git a/src/main/scala/dsptools/numbers/DspContext.scala b/src/main/scala/dsptools/numbers/DspContext.scala index d13f79f5..28f43d67 100644 --- a/src/main/scala/dsptools/numbers/DspContext.scala +++ b/src/main/scala/dsptools/numbers/DspContext.scala @@ -101,21 +101,21 @@ trait hasContext extends Any { } case class DspContext( - val overflowType: OverflowType = DspContext.defaultOverflowType, - val trimType: TrimType = DspContext.defaultTrimType, - val binaryPoint: Option[Int] = DspContext.defaultBinaryPoint, - val numBits: Option[Int] = DspContext.defaultNumBits, - val complexUse4Muls: Boolean = DspContext.defaultComplexUse4Muls, - val numMulPipes: Int = DspContext.defaultNumMulPipes, - val numAddPipes: Int = DspContext.defaultNumAddPipes, - val binaryPointGrowth: Int = DspContext.defaultBinaryPointGrowth) { + val overflowType: OverflowType = DspContext.defaultOverflowType, + val trimType: TrimType = DspContext.defaultTrimType, + val binaryPoint: Option[Int] = DspContext.defaultBinaryPoint, + val numBits: Option[Int] = DspContext.defaultNumBits, + val complexUse4Muls: Boolean = DspContext.defaultComplexUse4Muls, + val numMulPipes: Int = DspContext.defaultNumMulPipes, + val numAddPipes: Int = DspContext.defaultNumAddPipes, + val binaryPointGrowth: Int = DspContext.defaultBinaryPointGrowth) { require(numMulPipes >= 0, "# of pipeline registers for multiplication must be >= 0 ") require(numAddPipes >= 0, "# of pipeline registers for addition must be >= 0 ") require(binaryPointGrowth >= 0, "Binary point growth must be non-negative") numBits match { case Some(i) => require(i > 0, "# of bits must be > 0") - case _ => + case _ => } def complexMulPipe: Int = { diff --git a/src/main/scala/dsptools/numbers/algebra_types/Eq.scala b/src/main/scala/dsptools/numbers/algebra_types/Eq.scala index 78ec2af9..e6ab4d9f 100644 --- a/src/main/scala/dsptools/numbers/algebra_types/Eq.scala +++ b/src/main/scala/dsptools/numbers/algebra_types/Eq.scala @@ -2,7 +2,7 @@ package dsptools.numbers -import chisel3.{Data, Bool} +import chisel3.{Bool, Data} /** * Much of this is drawn from non/spire, but using Chisel Bools instead of @@ -16,11 +16,12 @@ import chisel3.{Data, Bool} * Moreover, `eqv` should form an equivalence relation. */ trait Eq[A <: Data] extends Any { + /** Returns `true` if `x` and `y` are equivalent, `false` otherwise. */ def eqv(x: A, y: A): Bool /** Returns `false` if `x` and `y` are equivalent, `true` otherwise. */ - def neqv(x: A, y:A ): Bool = !eqv(x, y) + def neqv(x: A, y: A): Bool = !eqv(x, y) /** * Constructs a new `Eq` instance for type `B` where 2 elements are @@ -38,8 +39,3 @@ object Eq { def by[A <: Data, B <: Data](f: A => B)(implicit e: Eq[B]): Eq[A] = new MappedEq(e)(f) } - - - - - diff --git a/src/main/scala/dsptools/numbers/algebra_types/Order.scala b/src/main/scala/dsptools/numbers/algebra_types/Order.scala index 56b2734c..02e17617 100644 --- a/src/main/scala/dsptools/numbers/algebra_types/Order.scala +++ b/src/main/scala/dsptools/numbers/algebra_types/Order.scala @@ -52,8 +52,8 @@ trait Order[A <: Data] extends Any with PartialOrder[A] { c.lt || c.eq } - def min(x: A, y: A): A = Mux(lt(x, y), x, y) - def max(x: A, y: A): A = Mux(gt(x, y), x, y) + def min(x: A, y: A): A = fixedpoint.shadow.Mux(lt(x, y), x, y) + def max(x: A, y: A): A = fixedpoint.shadow.Mux(gt(x, y), x, y) def compare(x: A, y: A): ComparisonBundle /** @@ -85,4 +85,3 @@ object Order { def compare(x: A, y: A): ComparisonBundle = f(x, y) } } - diff --git a/src/main/scala/dsptools/numbers/algebra_types/PartialOrder.scala b/src/main/scala/dsptools/numbers/algebra_types/PartialOrder.scala index f51c013c..f51e32b9 100644 --- a/src/main/scala/dsptools/numbers/algebra_types/PartialOrder.scala +++ b/src/main/scala/dsptools/numbers/algebra_types/PartialOrder.scala @@ -2,8 +2,8 @@ package dsptools.numbers -import chisel3.{Bool, Data, Mux} import chisel3.util.{Valid, ValidIO} +import chisel3.{Bool, Data} // Note: For type classing normal Chisel number data types like UInt, SInt, FixedPoint, etc. // you should *not* have to rely on PartialOrder (all comparisons to the same type are legal) @@ -33,16 +33,17 @@ import chisel3.util.{Valid, ValidIO} * false false = NaN (x and y cannot be compared) * true false = -1.0 (corresponds to x < y) * false true = 1.0 (corresponds to x > y) - * */ trait PartialOrder[A <: Data] extends Any with Eq[A] { self => + /** Result of comparing `x` with `y`. Returns ValidIO[ComparisonBundle] * with `valid` false if operands are not comparable. If operands are * comparable, `bits.lt` will be true if `x` < `y` and `bits.eq` will * be true if `x` = `y`` */ def partialCompare(x: A, y: A): ValidIO[ComparisonBundle] + /** Result of comparing `x` with `y`. Returns None if operands * are not comparable. If operands are comparable, returns Some[Int] * where the Int sign is: @@ -54,7 +55,7 @@ trait PartialOrder[A <: Data] extends Any with Eq[A] { /** Returns Some(x) if x <= y, Some(y) if x > y, otherwise None. */ def pmin(x: A, y: A): ValidIO[A] = { val c = partialCompare(x, y) - val value = Mux(c.bits.lt, x, y) + val value = fixedpoint.shadow.Mux(c.bits.lt, x, y) val ret = Valid(value) ret.valid := c.valid ret @@ -63,7 +64,7 @@ trait PartialOrder[A <: Data] extends Any with Eq[A] { /** Returns Some(x) if x >= y, Some(y) if x < y, otherwise None. */ def pmax(x: A, y: A): ValidIO[A] = { val c = partialCompare(x, y) - val value = Mux(!c.bits.lt, x, y) + val value = fixedpoint.shadow.Mux(!c.bits.lt, x, y) val ret = Valid(value) ret.valid := c.valid ret @@ -84,7 +85,7 @@ trait PartialOrder[A <: Data] extends Any with Eq[A] { } def gteqv(x: A, y: A): Bool = lteqv(y, x) - def gt(x: A, y: A): Bool = lt(y, x) + def gt(x: A, y: A): Bool = lt(y, x) /** * Defines a partial order on `B` by mapping `B` to `A` using `f` and using `A`s @@ -98,7 +99,8 @@ trait PartialOrder[A <: Data] extends Any with Eq[A] { def reverse: PartialOrder[A] = new ReversedPartialOrder(this) } -private[numbers] class MappedPartialOrder[A <: Data, B <: Data](partialOrder: PartialOrder[B])(f: A => B) extends PartialOrder[A] { +private[numbers] class MappedPartialOrder[A <: Data, B <: Data](partialOrder: PartialOrder[B])(f: A => B) + extends PartialOrder[A] { def partialCompare(x: A, y: A): ValidIO[ComparisonBundle] = partialOrder.partialCompare(f(x), f(y)) } diff --git a/src/main/scala/dsptools/numbers/algebra_types/Ring.scala b/src/main/scala/dsptools/numbers/algebra_types/Ring.scala index 530d4080..adca5880 100644 --- a/src/main/scala/dsptools/numbers/algebra_types/Ring.scala +++ b/src/main/scala/dsptools/numbers/algebra_types/Ring.scala @@ -10,8 +10,8 @@ object Ring { } trait Ring[A] extends Any with spire.algebra.Ring[A] { - def plusContext(f: A, g: A): A - def minusContext(f: A, g: A): A - def timesContext(f: A, g: A): A + def plusContext(f: A, g: A): A + def minusContext(f: A, g: A): A + def timesContext(f: A, g: A): A def negateContext(f: A): A } diff --git a/src/main/scala/dsptools/numbers/algebra_types/Signed.scala b/src/main/scala/dsptools/numbers/algebra_types/Signed.scala index e3aa0d73..7be2467a 100644 --- a/src/main/scala/dsptools/numbers/algebra_types/Signed.scala +++ b/src/main/scala/dsptools/numbers/algebra_types/Signed.scala @@ -3,7 +3,7 @@ package dsptools.numbers import chisel3.util.ShiftRegister -import chisel3.{Bool, Data, Mux} +import chisel3.{Bool, Data} import dsptools.hasContext /** @@ -17,6 +17,7 @@ import dsptools.hasContext * something has a positive sign. */ trait Signed[A] extends Any { + /** Returns Zero if `a` is 0, Positive if `a` is positive, and Negative is `a` is negative. */ def sign(a: A): Sign = Sign(signum(a)) @@ -28,11 +29,11 @@ trait Signed[A] extends Any { //noinspection ScalaStyle def context_abs(a: A): A - def isSignZero(a: A): Bool = signum(a).eq + def isSignZero(a: A): Bool = signum(a).eq def isSignPositive(a: A): Bool = !isSignZero(a) && !isSignNegative(a) def isSignNegative(a: A): Bool = signum(a).lt - def isSignNonZero(a: A): Bool = !isSignZero(a) + def isSignNonZero(a: A): Bool = !isSignZero(a) def isSignNonPositive(a: A): Bool = !isSignPositive(a) def isSignNonNegative(a: A): Bool = !isSignNegative(a) @@ -48,10 +49,13 @@ object Signed { private class OrderedRingIsSigned[A <: Data](implicit o: Order[A], r: Ring[A]) extends Signed[A] with hasContext { def signum(a: A): ComparisonBundle = o.compare(a, r.zero) def abs(a: A): A = { - Mux(signum(a).lt, r.negate(a), a) + fixedpoint.shadow.Mux(signum(a).lt, r.negate(a), a) } def context_abs(a: A): A = { - Mux(signum(ShiftRegister(a, context.numAddPipes)).lt, r.negateContext(a), ShiftRegister(a, context.numAddPipes)) + fixedpoint.shadow.Mux( + signum(ShiftRegister(a, context.numAddPipes)).lt, + r.negateContext(a), + ShiftRegister(a, context.numAddPipes) + ) } } - diff --git a/src/main/scala/dsptools/numbers/algebra_types/helpers/Sign.scala b/src/main/scala/dsptools/numbers/algebra_types/helpers/Sign.scala index ab8d122c..c882bf01 100644 --- a/src/main/scala/dsptools/numbers/algebra_types/helpers/Sign.scala +++ b/src/main/scala/dsptools/numbers/algebra_types/helpers/Sign.scala @@ -19,11 +19,11 @@ import scala.language.implicitConversions */ sealed class Sign(zeroInit: Option[Boolean] = None, negInit: Option[Boolean] = None) extends Bundle { // import Sign._ - val zero = zeroInit.map{_.B}.getOrElse(Bool()) + val zero = zeroInit.map { _.B }.getOrElse(Bool()) // ignore neg if zero is true - val neg = negInit.map{_.B}.getOrElse(Bool()) + val neg = negInit.map { _.B }.getOrElse(Bool()) - def unary_-(): Sign = Sign(this.zero, !this.neg) + def unary_- : Sign = Sign(this.zero, !this.neg) def *(that: Sign): Sign = Sign( this.zero || that.zero, @@ -35,7 +35,7 @@ sealed class Sign(zeroInit: Option[Boolean] = None, negInit: Option[Boolean] = N Sign(zero, if (evenPow) false.B else neg) } - // LSB indicates even or oddness -- only negative if this is negative and + // LSB indicates even or oddness -- only negative if this is negative and // it's raised by an odd power def **(that: UInt): Sign = Sign(this.zero, this.neg && that(0)) } @@ -46,18 +46,20 @@ object Sign { case object Negative extends Sign(Some(false), Some(true)) def apply(zero: Bool, neg: Bool): Sign = { - val zeroLit = zero.litOption.map{_ != BigInt(0)} - val negLit = neg.litOption.map{_ != BigInt(0)} + val zeroLit = zero.litOption.map { _ != BigInt(0) } + val negLit = neg.litOption.map { _ != BigInt(0) } val isLit = zeroLit.isDefined && negLit.isDefined - val wireWrapIfNotLit: Sign => Sign = s => if (isLit) { s } else Wire(s) + val wireWrapIfNotLit: Sign => Sign = s => + if (isLit) { s } + else Wire(s) val bundle = wireWrapIfNotLit( - new Sign(zeroInit=zeroLit, negInit=negLit) + new Sign(zeroInit = zeroLit, negInit = negLit) ) if (!zero.isLit) { bundle.zero := zero } if (!neg.isLit) { - bundle.neg := neg + bundle.neg := neg } bundle } @@ -74,19 +76,21 @@ object Sign { def combine(a: Sign, b: Sign): Sign = a * b override def sign(a: Sign): Sign = a - def signum(a: Sign): ComparisonBundle = ComparisonHelper(a.zero, a.neg) - def abs(a: Sign): Sign = if (a == Negative) Positive else a - def context_abs(a: Sign): Sign = if (a == Negative) Positive else a + def signum(a: Sign): ComparisonBundle = ComparisonHelper(a.zero, a.neg) + def abs(a: Sign): Sign = if (a == Negative) Positive else a + def context_abs(a: Sign): Sign = if (a == Negative) Positive else a def compare(x: Sign, y: Sign): ComparisonBundle = { - val eq = Mux(x.zero, + val eq = fixedpoint.shadow.Mux( + x.zero, // if x is zero, y must also be zero for equality y.zero, // if x is not zero, y must not be zero and must have the same sign !y.zero && (x.neg === y.neg) ) // lt only needs to be correct when eq not true - val lt = Mux(x.zero, + val lt = fixedpoint.shadow.Mux( + x.zero, // if x is zero, then true when y positive !y.zero && !y.neg, // if x is not zero, then true when x is negative and y not negative @@ -95,6 +99,10 @@ object Sign { ComparisonHelper(eq, lt) } + + override def reverse: SignAlgebra = new SignAlgebra { + override def compare(x: Sign, y: Sign): ComparisonBundle = super.compare(y, x) + } } implicit final val SignAlgebra = new SignAlgebra @@ -103,16 +111,17 @@ object Sign { Multiplicative(SignAlgebra) //scalastyle:off method.name - implicit def SignAction[A<: Data](implicit A: AdditiveGroup[A]): MultiplicativeAction[A, Sign] = + implicit def SignAction[A <: Data](implicit A: AdditiveGroup[A]): MultiplicativeAction[A, Sign] = new MultiplicativeAction[A, Sign] with hasContext { // Multiply a # by a sign def gtimesl(s: Sign, a: A): A = { - Mux(ShiftRegister(s.zero, context.numAddPipes), + fixedpoint.shadow.Mux( + ShiftRegister(s.zero, context.numAddPipes), ShiftRegister(A.zero, context.numAddPipes), - Mux(ShiftRegister(s.neg, context.numAddPipes), A.negate(a), ShiftRegister(a, context.numAddPipes)) + fixedpoint.shadow + .Mux(ShiftRegister(s.neg, context.numAddPipes), A.negate(a), ShiftRegister(a, context.numAddPipes)) ) } def gtimesr(a: A, s: Sign): A = gtimesl(s, a) } } - diff --git a/src/main/scala/dsptools/numbers/binary_types/BinaryRepresentation.scala b/src/main/scala/dsptools/numbers/binary_types/BinaryRepresentation.scala index 2c600afa..6a6ec752 100644 --- a/src/main/scala/dsptools/numbers/binary_types/BinaryRepresentation.scala +++ b/src/main/scala/dsptools/numbers/binary_types/BinaryRepresentation.scala @@ -2,18 +2,18 @@ package dsptools.numbers -import chisel3.{Data, UInt, Bool} +import chisel3.{Bool, Data, UInt} object BinaryRepresentation { def apply[A <: Data](implicit A: BinaryRepresentation[A]): BinaryRepresentation[A] = A } trait BinaryRepresentation[A <: Data] extends Any { - def shl(a: A, n: Int): A + def shl(a: A, n: Int): A def shl(a: A, n: UInt): A // For negative signed #'s, this is actually round to negative infinity - def shr(a: A, n: Int): A - def shr(a: A, n: UInt): A + def shr(a: A, n: Int): A + def shr(a: A, n: UInt): A def signBit(a: A): Bool // Rounds to zero (positive, negative consistent!) diff --git a/src/main/scala/dsptools/numbers/binary_types/NumberBits.scala b/src/main/scala/dsptools/numbers/binary_types/NumberBits.scala index 7ad03917..0c4c7d7a 100644 --- a/src/main/scala/dsptools/numbers/binary_types/NumberBits.scala +++ b/src/main/scala/dsptools/numbers/binary_types/NumberBits.scala @@ -4,16 +4,14 @@ package dsptools.numbers import chisel3._ -trait RealBits[A <: Data] extends Any with Real[A] with ChiselConvertableFrom[A] with BinaryRepresentation[A] { -} +trait RealBits[A <: Data] extends Any with Real[A] with ChiselConvertableFrom[A] with BinaryRepresentation[A] {} object RealBits { def apply[A <: Data](implicit A: RealBits[A]): RealBits[A] = A } -trait IntegerBits[A <: Data] extends Any with RealBits[A] with Integer[A] { -} +trait IntegerBits[A <: Data] extends Any with RealBits[A] with Integer[A] {} object IntegerBits { def apply[A <: Data](implicit A: IntegerBits[A]): IntegerBits[A] = A -} \ No newline at end of file +} diff --git a/src/main/scala/dsptools/numbers/blackbox_compatibility/DspRealVerilatorBB.scala b/src/main/scala/dsptools/numbers/blackbox_compatibility/DspRealVerilatorBB.scala index 1af728ae..b87b2a3f 100644 --- a/src/main/scala/dsptools/numbers/blackbox_compatibility/DspRealVerilatorBB.scala +++ b/src/main/scala/dsptools/numbers/blackbox_compatibility/DspRealVerilatorBB.scala @@ -58,7 +58,7 @@ class BBFEquals extends BlackboxTwoOperandBool class BBFNotEquals extends BlackboxTwoOperandBool -/** Math operations from IEEE.1364-2005 **/ +/** Math operations from IEEE.1364-2005 * */ class BBFLn extends BlackboxOneOperand class BBFLog10 extends BlackboxOneOperand diff --git a/src/main/scala/dsptools/numbers/blackbox_compatibility/TrigUtility.scala b/src/main/scala/dsptools/numbers/blackbox_compatibility/TrigUtility.scala index 2ea58fe1..ad18e9dd 100644 --- a/src/main/scala/dsptools/numbers/blackbox_compatibility/TrigUtility.scala +++ b/src/main/scala/dsptools/numbers/blackbox_compatibility/TrigUtility.scala @@ -14,7 +14,7 @@ object TrigUtility { // @ https://en.wikipedia.org/wiki/Bernoulli_number def bernoulli(n: Int): Double = { this.synchronized { - var temp: Array[Double] = Array.fill(n + 1)(0.0) + var temp: Array[Double] = Array.fill(n + 1)(0.0) for (m <- 0 to n) { temp(m) = 1.toDouble / (m + 1) for (j <- m to 1 by -1) { @@ -22,30 +22,34 @@ object TrigUtility { } } // Bn - temp(0) + temp(0) } } def factorial(n: Int): Int = (1 to n).product - def combination(n: Int, k: Int): Double = factorial(n).toDouble / factorial(k) / factorial(n - k) + def combination(n: Int, k: Int): Double = factorial(n).toDouble / factorial(k) / factorial(n - k) // See Taylor series for trig functions @ https://en.wikipedia.org/wiki/Taylor_series def sinCoeff(nmax: Int): Seq[(Double, Double)] = { - (0 to nmax) map { n => { - val fact = factorial(2 * n + 1) - val factOutOfBounds = fact / err - // If you divide by too large of a number, things go crazy - val scaleFactor = if (factOutOfBounds <= 1) 1.0 else fact.toDouble / err - val denom = if (factOutOfBounds <= 1) fact else err - (math.pow(-1, n) / denom, scaleFactor) - } } + (0 to nmax).map { n => + { + val fact = factorial(2 * n + 1) + val factOutOfBounds = fact / err + // If you divide by too large of a number, things go crazy + val scaleFactor = if (factOutOfBounds <= 1) 1.0 else fact.toDouble / err + val denom = if (factOutOfBounds <= 1) fact else err + (math.pow(-1, n) / denom, scaleFactor) + } + } } def cosCoeff(nmax: Int): Seq[Double] = { (0 to nmax).map(n => math.pow(-1, n) / factorial(2 * n)) } def tanCoeff(nmax: Int): Seq[Double] = { - (1 to nmax).map(n => bernoulli(2 * n) * math.pow(2, 2 * n) * (math.pow(2, 2 * n) - 1) * math.pow(-1, n - 1) / factorial(2 * n)) + (1 to nmax).map(n => + bernoulli(2 * n) * math.pow(2, 2 * n) * (math.pow(2, 2 * n) - 1) * math.pow(-1, n - 1) / factorial(2 * n) + ) } // Fast convergence of arctan (arcsin, arccos derived) @@ -54,12 +58,11 @@ object TrigUtility { // Is Even if (j % 2 == 0) { val i = j / 2 - def sumTerm(k : Int) = math.pow(-1, k) * combination(4 * m, 2 * k) + def sumTerm(k: Int) = math.pow(-1, k) * combination(4 * m, 2 * k) math.pow(-1, i + 1) * ((i + 1) to (2 * m)).map(k => sumTerm(k)).sum - } - else { + } else { val i = (j + 1) / 2 - def sumTerm(k : Int) = math.pow(-1, k) * combination(4 * m, 2 * k + 1) + def sumTerm(k: Int) = math.pow(-1, k) * combination(4 * m, 2 * k + 1) math.pow(-1, i + 1) * (i to (2 * m - 1)).map(k => sumTerm(k)).sum } } @@ -71,4 +74,4 @@ object TrigUtility { (0 to (4 * m - 2)).map(j => a(j, m) / math.pow(-1, m + 1) / math.pow(4, m) / (4 * m + j + 1)) } -} \ No newline at end of file +} diff --git a/src/main/scala/dsptools/numbers/chisel_concrete/DspComplex.scala b/src/main/scala/dsptools/numbers/chisel_concrete/DspComplex.scala index 924beece..65d7edbf 100644 --- a/src/main/scala/dsptools/numbers/chisel_concrete/DspComplex.scala +++ b/src/main/scala/dsptools/numbers/chisel_concrete/DspComplex.scala @@ -12,7 +12,7 @@ import scala.reflect.ClassTag object DspComplex { - def apply[T <: Data:Ring](gen: T): DspComplex[T] = { + def apply[T <: Data: Ring](gen: T): DspComplex[T] = { if (gen.isLit) throw DspException("Cannot use Lit in single argument DspComplex.apply") apply(gen.cloneType, gen.cloneType) } @@ -20,10 +20,10 @@ object DspComplex { // If real, imag are literals, the literals are carried through // In reality, real and imag should have the same type, so should be using single argument // apply if you aren't trying t create a Lit - def apply[T <: Data:Ring](real: T, imag: T): DspComplex[T] = { + def apply[T <: Data: Ring](real: T, imag: T): DspComplex[T] = { val newReal = if (real.isLit) real.cloneType else real val newImag = if (imag.isLit) imag.cloneType else imag - if(real.isLit && imag.isLit) { + if (real.isLit && imag.isLit) { new DspComplex(newReal, newImag).Lit(_.real -> real, _.imag -> imag) } else { new DspComplex(newReal, newImag) @@ -32,7 +32,7 @@ object DspComplex { // Needed for assigning to results of operations; should not use in user code for making wires // Assumes real, imag are not literals - def wire[T <: Data:Ring](real: T, imag: T): DspComplex[T] = { + def wire[T <: Data: Ring](real: T, imag: T): DspComplex[T] = { val result = Wire(DspComplex(real.cloneType, imag.cloneType)) result.real := real result.imag := imag @@ -41,30 +41,33 @@ object DspComplex { // Constant j // TODO(Paul): this call to wire() should be removed when chisel has literal bundles - def j[T <: Data:Ring] : DspComplex[T] = DspComplex(Ring[T].zero, Ring[T].one) + def j[T <: Data: Ring]: DspComplex[T] = DspComplex(Ring[T].zero, Ring[T].one) // Creates a DspComplex literal of type DspComplex[T] from a Breeze Complex // Note: when T is FixedPoint, the # of fractional bits is determined via DspContext - def apply[T <: Data:Ring:ConvertableTo](c: Complex): DspComplex[T] = { + def apply[T <: Data: Ring: ConvertableTo](c: Complex): DspComplex[T] = { DspComplex(ConvertableTo[T].fromDouble(c.real), ConvertableTo[T].fromDouble(c.imag)) } - // Creates a DspComplex literal where real and imaginary parts have type T (and binary point + // Creates a DspComplex literal where real and imaginary parts have type T (and binary point // determined by binaryPoint of t) - def proto[T <: Data:Ring:ConvertableTo](c: Complex, t: T): DspComplex[T] = { + def proto[T <: Data: Ring: ConvertableTo](c: Complex, t: T): DspComplex[T] = { DspComplex(ConvertableTo[T].fromDouble(c.real, t), ConvertableTo[T].fromDouble(c.imag, t)) } - // Creates a DspComplex literal where real and imaginary parts have type T (width/binary point + // Creates a DspComplex literal where real and imaginary parts have type T (width/binary point // determined by width/binaryPoint of t) - def protoWithFixedWidth[T <: Data:Ring:ConvertableTo](c: Complex, t: T): DspComplex[T] = { - DspComplex(ConvertableTo[T].fromDoubleWithFixedWidth(c.real, t), - ConvertableTo[T].fromDoubleWithFixedWidth(c.imag, t)) + def protoWithFixedWidth[T <: Data: Ring: ConvertableTo](c: Complex, t: T): DspComplex[T] = { + DspComplex( + ConvertableTo[T].fromDoubleWithFixedWidth(c.real, t), + ConvertableTo[T].fromDoubleWithFixedWidth(c.imag, t) + ) } } -class DspComplex[T <: Data:Ring](val real: T, val imag: T)(implicit val ct: ClassTag[DspComplex[T]]) extends Bundle - with ForceElementwiseConnect[DspComplex[T]] { - +class DspComplex[T <: Data: Ring](val real: T, val imag: T)(implicit val ct: ClassTag[DspComplex[T]]) + extends Bundle + with ForceElementwiseConnect[DspComplex[T]] { + // So old DSP code doesn't break def imaginary(dummy: Int = 0): T = imag diff --git a/src/main/scala/dsptools/numbers/chisel_concrete/DspReal.scala b/src/main/scala/dsptools/numbers/chisel_concrete/DspReal.scala index 819fb7d9..bca48a2c 100644 --- a/src/main/scala/dsptools/numbers/chisel_concrete/DspReal.scala +++ b/src/main/scala/dsptools/numbers/chisel_concrete/DspReal.scala @@ -11,7 +11,7 @@ class DspReal() extends Bundle { val node: UInt = Output(UInt(DspReal.underlyingWidth.W)) - private def oneOperandOperator(blackbox_gen: => BlackboxOneOperand) : DspReal = { + private def oneOperandOperator(blackbox_gen: => BlackboxOneOperand): DspReal = { val blackbox = blackbox_gen blackbox.io.in := node val out = Wire(DspReal()) @@ -19,7 +19,7 @@ class DspReal() extends Bundle { out } - private def twoOperandOperator(arg1: DspReal, blackbox_gen: => BlackboxTwoOperand) : DspReal = { + private def twoOperandOperator(arg1: DspReal, blackbox_gen: => BlackboxTwoOperand): DspReal = { val blackbox = blackbox_gen blackbox.io.in1 := node blackbox.io.in2 := arg1.node @@ -28,7 +28,7 @@ class DspReal() extends Bundle { out } - private def twoOperandBool(arg1: DspReal, blackbox_gen: => BlackboxTwoOperandBool) : Bool = { + private def twoOperandBool(arg1: DspReal, blackbox_gen: => BlackboxTwoOperandBool): Bool = { val blackbox = blackbox_gen blackbox.io.in1 := node blackbox.io.in2 := arg1.node @@ -37,71 +37,71 @@ class DspReal() extends Bundle { out } - def + (arg1: DspReal): DspReal = { + def +(arg1: DspReal): DspReal = { twoOperandOperator(arg1, Module(new BBFAdd())) } - def - (arg1: DspReal): DspReal = { + def -(arg1: DspReal): DspReal = { twoOperandOperator(arg1, Module(new BBFSubtract())) } - def * (arg1: DspReal): DspReal = { + def *(arg1: DspReal): DspReal = { twoOperandOperator(arg1, Module(new BBFMultiply())) } - def / (arg1: DspReal): DspReal = { + def /(arg1: DspReal): DspReal = { twoOperandOperator(arg1, Module(new BBFDivide())) } - def > (arg1: DspReal): Bool = { + def >(arg1: DspReal): Bool = { twoOperandBool(arg1, Module(new BBFGreaterThan())) } - def >= (arg1: DspReal): Bool = { + def >=(arg1: DspReal): Bool = { twoOperandBool(arg1, Module(new BBFGreaterThanEquals())) } - def < (arg1: DspReal): Bool = { + def <(arg1: DspReal): Bool = { twoOperandBool(arg1, Module(new BBFLessThan())) } - def <= (arg1: DspReal): Bool = { + def <=(arg1: DspReal): Bool = { twoOperandBool(arg1, Module(new BBFLessThanEquals())) } - def === (arg1: DspReal): Bool = { + def ===(arg1: DspReal): Bool = { twoOperandBool(arg1, Module(new BBFEquals())) } - def != (arg1: DspReal): Bool = { + def !=(arg1: DspReal): Bool = { twoOperandBool(arg1, Module(new BBFNotEquals())) } - def ln (dummy: Int = 0): DspReal = { + def ln(dummy: Int = 0): DspReal = { oneOperandOperator(Module(new BBFLn())) } - def log10 (dummy: Int = 0): DspReal = { + def log10(dummy: Int = 0): DspReal = { oneOperandOperator(Module(new BBFLog10())) } - def exp (dummy: Int = 0): DspReal = { + def exp(dummy: Int = 0): DspReal = { oneOperandOperator(Module(new BBFExp())) } - def sqrt (dummy: Int = 0): DspReal = { + def sqrt(dummy: Int = 0): DspReal = { oneOperandOperator(Module(new BBFSqrt())) } - def pow (arg1: DspReal): DspReal = { + def pow(arg1: DspReal): DspReal = { twoOperandOperator(arg1, Module(new BBFPow())) } - def floor (dummy: Int = 0): DspReal = { + def floor(dummy: Int = 0): DspReal = { oneOperandOperator(Module(new BBFFloor())) } - def ceil (dummy: Int = 0): DspReal = { + def ceil(dummy: Int = 0): DspReal = { oneOperandOperator(Module(new BBFCeil())) } @@ -131,21 +131,21 @@ class DspReal() extends Bundle { // Swept in increments of 0.0001pi, and got ~11 decimal digits of accuracy // Can add more half angle recursion for more precision //scalastyle:off magic.number - def sin (dummy: Int = 0): DspReal = { + def sin(dummy: Int = 0): DspReal = { if (backendIsVerilator) { - // Taylor series; Works best close to 0 (-pi/2, pi/2) -- so normalize! + // Taylor series; Works best close to 0 (-pi/2, pi/2) -- so normalize! def sinPiOver2(in: DspReal): DspReal = { val nmax = TrigUtility.numTaylorTerms - 1 //val xpow = (0 to nmax).map(n => in.pow(DspReal(2 * n + 1))) // Multiply by extra x later improves accuracy val xpow = (0 to nmax).map(n => in.pow(DspReal(2 * n))) // Break coefficient into two step process b/c of precision limitations - val terms = TrigUtility.sinCoeff(nmax).zip(xpow) map { case ((c, scale), x) => DspReal(c) * x / DspReal(scale)} + val terms = TrigUtility.sinCoeff(nmax).zip(xpow).map { case ((c, scale), x) => DspReal(c) * x / DspReal(scale) } terms.reduceRight(_ + _) * in } val num2Pi = (this / twoPi).truncate() // Repeats every 2*pi, so normalize to -pi, pi - val normalized2Pi= this - num2Pi * twoPi + val normalized2Pi = this - num2Pi * twoPi val temp1 = Mux(normalized2Pi > pi, normalized2Pi - twoPi, normalized2Pi) val normalizedPi = Mux(normalized2Pi < negPi, normalized2Pi + twoPi, temp1) val q2 = normalizedPi > halfPi @@ -153,7 +153,7 @@ class DspReal() extends Bundle { // sin(x + pi) = -sin(x) // sin(pi - x) = sin(x) val temp2 = Mux(q2, pi - normalizedPi, normalizedPi) - val normalizedHalfPi= Mux(q3, pi + normalizedPi, temp2) + val normalizedHalfPi = Mux(q3, pi + normalizedPi, temp2) // Half angle -> sin(x/2) = (-1)^(floor(x/(2pi)) * sqrt((1-cosx)/2)) // x negative -> sin(x/2) = -1 * sqrt((1-cosx)/2)) @@ -169,47 +169,46 @@ class DspReal() extends Bundle { val sinNegPiOver4Out = sinPiOver2((normalizedHalfPi + halfPi) / DspReal(2.0)) val sinPiOver2Out = sinPiOver2(normalizedHalfPi) - val outTemp1 = Mux( normalizedHalfPi > pi/DspReal(4), - one - DspReal(2) * sinPiOver4Out * sinPiOver4Out, - sinPiOver2Out) - val outTemp2 = Mux(normalizedHalfPi < pi/DspReal(-4), - DspReal(-1) * (one - DspReal(2) * sinNegPiOver4Out * sinNegPiOver4Out), - outTemp1) + val outTemp1 = + Mux(normalizedHalfPi > pi / DspReal(4), one - DspReal(2) * sinPiOver4Out * sinPiOver4Out, sinPiOver2Out) + val outTemp2 = Mux( + normalizedHalfPi < pi / DspReal(-4), + DspReal(-1) * (one - DspReal(2) * sinNegPiOver4Out * sinNegPiOver4Out), + outTemp1 + ) Mux(q3, zero - outTemp2, outTemp2) - } - else { + } else { oneOperandOperator(Module(new BBFSin())) } } - def cos (dummy: Int = 0): DspReal = { + def cos(dummy: Int = 0): DspReal = { if (backendIsVerilator) (this + halfPi).sin() else oneOperandOperator(Module(new BBFCos())) } // Swept input at 0.0001pi increments. For tan < 1e9, ~8 decimal digit precision (fractional) // WARNING: tan blows up (more likely to be wrong when abs is close to pi/2) - def tan (dummy: Int = 0): DspReal = { + def tan(dummy: Int = 0): DspReal = { if (backendIsVerilator) { def tanPiOver2(in: DspReal): DspReal = { - in.sin()/in.cos() + in.sin() / in.cos() } val numPi = (this / pi).truncate() // Repeats every pi, so normalize to -pi/2, pi/2 // tan(x + pi) = tan(x) - val normalizedPi= this - numPi * pi + val normalizedPi = this - numPi * pi val temp = Mux(normalizedPi > halfPi, normalizedPi - pi, normalizedPi) val normalizedHalfPi = Mux(normalizedPi < negHalfPi, normalizedPi + pi, temp) - + // Also note: tan(x) = 2*tan(x/2)/(1-tan^2(x/2)) tanPiOver2(normalizedHalfPi) - } - else oneOperandOperator(Module(new BBFTan())) + } else oneOperandOperator(Module(new BBFTan())) } // Correct to 9 decimal digits sweeping by 0.0001pi // See http://myweb.lmu.edu/hmedina/papers/reprintmonthly156-161-medina.pdf - def atan (dummy: Int = 0): DspReal = { + def atan(dummy: Int = 0): DspReal = { if (backendIsVerilator) { def arctanPiOver2(in: DspReal): DspReal = { val m = TrigUtility.atanM @@ -218,109 +217,105 @@ class DspReal() extends Bundle { // Move single multiply by x until later val xpow1 = (1 to (2 * m)).map(j => in.pow(DspReal(2 * j - 2))) val xpow2 = (0 to (4 * m - 2)).map(j => in.pow(DspReal(4 * m + j))) - val terms1 = TrigUtility.atanCoeff1(m).zip(xpow1) map { case (c, x) => DspReal(c) * x } - val terms2 = TrigUtility.atanCoeff2(m).zip(xpow2) map { case (c, x) => DspReal(c) * x } + val terms1 = TrigUtility.atanCoeff1(m).zip(xpow1).map { case (c, x) => DspReal(c) * x } + val terms2 = TrigUtility.atanCoeff2(m).zip(xpow2).map { case (c, x) => DspReal(c) * x } (terms1 ++ terms2).reduceRight(_ + _) * in } - val isNeg = this.signBit() + val isNeg = this.signBit // arctan(-x) = -arctan(x) val inTemp = this.abs() // arctan(x) = pi/2 - arctan(1/x) for x > 0 // Approximation accuracy in [0, 1] val outTemp = Mux(inTemp > one, halfPi - arctanPiOver2(one / inTemp), arctanPiOver2(inTemp)) Mux(isNeg, zero - outTemp, outTemp) - } - else oneOperandOperator(Module(new BBFATan())) + } else oneOperandOperator(Module(new BBFATan())) } // See https://en.wikipedia.org/wiki/Inverse_trigonometric_functions // Must be -1 <= x <= 1 - def asin (dummy: Int = 0): DspReal = { + def asin(dummy: Int = 0): DspReal = { if (backendIsVerilator) { val sqrtIn = one - (this * this) val atanIn = this / (one + sqrtIn.sqrt()) DspReal(2) * atanIn.atan() - } - else oneOperandOperator(Module(new BBFASin())) + } else oneOperandOperator(Module(new BBFASin())) } // Must be -1 <= x <= 1 - def acos (dummy: Int = 0): DspReal = { + def acos(dummy: Int = 0): DspReal = { if (backendIsVerilator) { halfPi - this.asin() - } - else oneOperandOperator(Module(new BBFACos())) + } else oneOperandOperator(Module(new BBFACos())) } // Output in the range (-pi, pi] // y.atan2(x) - def atan2 (arg1: DspReal): DspReal = { + def atan2(arg1: DspReal): DspReal = { if (backendIsVerilator) { val x = arg1 val y = this val atanArg = y / x val atanRes = atanArg.atan() val muxIn: Iterable[(Bool, DspReal)] = Iterable( - (x > zero) -> atanRes, - (x.signBit() && !y.signBit()) -> (atanRes + pi), - (x.signBit() && y.signBit()) -> (atanRes - pi), - (x === zero && y > zero) -> halfPi, - (x === zero && y.signBit()) -> negHalfPi, - (x === zero && y === zero) -> atanArg // undefined + (x > zero) -> atanRes, + (x.signBit && !y.signBit) -> (atanRes + pi), + (x.signBit && y.signBit) -> (atanRes - pi), + (x === zero && y > zero) -> halfPi, + (x === zero && y.signBit) -> negHalfPi, + (x === zero && y === zero) -> atanArg // undefined ) Mux1H(muxIn) - } - else twoOperandOperator(arg1, Module(new BBFATan2())) + } else twoOperandOperator(arg1, Module(new BBFATan2())) } - def hypot (arg1: DspReal): DspReal = { + def hypot(arg1: DspReal): DspReal = { if (backendIsVerilator) (this * this + arg1 * arg1).sqrt() else twoOperandOperator(arg1, Module(new BBFHypot())) } // See https://en.wikipedia.org/wiki/Hyperbolic_function - def sinh (dummy: Int = 0): DspReal = { + def sinh(dummy: Int = 0): DspReal = { if (backendIsVerilator) (this.exp() - (zero - this).exp()) / DspReal(2) else oneOperandOperator(Module(new BBFSinh())) } - def cosh (dummy: Int = 0): DspReal = { + def cosh(dummy: Int = 0): DspReal = { if (backendIsVerilator) (this.exp() + (zero - this).exp()) / DspReal(2) else oneOperandOperator(Module(new BBFCosh())) } - def tanh (dummy: Int = 0): DspReal = { + def tanh(dummy: Int = 0): DspReal = { if (backendIsVerilator) (this.exp() - (zero - this).exp()) / (this.exp() + (zero - this).exp()) else oneOperandOperator(Module(new BBFTanh())) } // Requires Breeze for testing: - def asinh (dummy: Int = 0): DspReal = { + def asinh(dummy: Int = 0): DspReal = { if (backendIsVerilator) ((this * this + one).sqrt() + this).ln() else oneOperandOperator(Module(new BBFASinh())) } // x >= 1 - def acosh (dummy: Int = 0): DspReal = { + def acosh(dummy: Int = 0): DspReal = { if (backendIsVerilator) ((this * this - one).sqrt() + this).ln() else oneOperandOperator(Module(new BBFACosh())) } // |x| < 1 - def atanh (dummy: Int = 0): DspReal = { + def atanh(dummy: Int = 0): DspReal = { if (backendIsVerilator) ((one + this) / (one - this)).ln() / DspReal(2) else oneOperandOperator(Module(new BBFATanh())) } - + // Not used directly -- there's an equivalent in the type classes (was causing some confusion) /* def intPart(dummy: Int = 0): DspReal = { oneOperandOperator(Module(new BBFIntPart())) } - */ + */ /** Returns this Real's value rounded to a signed integer. * Behavior on overflow (possible with large exponent terms) is undefined. @@ -350,8 +345,8 @@ object DspReal { */ def apply(value: Double): DspReal = { // See http://stackoverflow.com/questions/21212993/unsigned-variables-in-scala - def longAsUnsignedBigInt(in: Long): BigInt = (BigInt(in >>> 1) << 1) + (in & 1) - def doubleToBigInt(in: Double): BigInt = longAsUnsignedBigInt(java.lang.Double.doubleToRawLongBits(in)) + def longAsUnsignedBigInt(in: Long): BigInt = (BigInt(in >>> 1) << 1) + (in & 1) + def doubleToBigInt(in: Double): BigInt = longAsUnsignedBigInt(java.lang.Double.doubleToRawLongBits(in)) (new DspReal()).Lit(_.node -> doubleToBigInt(value).U) } diff --git a/src/main/scala/dsptools/numbers/chisel_concrete/RealTrig.scala b/src/main/scala/dsptools/numbers/chisel_concrete/RealTrig.scala index 059f693b..76f2d745 100644 --- a/src/main/scala/dsptools/numbers/chisel_concrete/RealTrig.scala +++ b/src/main/scala/dsptools/numbers/chisel_concrete/RealTrig.scala @@ -4,23 +4,23 @@ package dsptools.numbers // Make using these ops more like using math.opName object RealTrig { - def ln(x: DspReal) = x.ln() + def ln(x: DspReal) = x.ln() def log10(x: DspReal) = x.log10() - def exp(x: DspReal) = x.exp() - def sqrt(x: DspReal) = x.sqrt() - def pow(x: DspReal, n: DspReal) = x.pow(n) - def sin(x: DspReal) = x.sin() - def cos(x: DspReal) = x.cos() - def tan(x: DspReal) = x.tan() - def atan(x: DspReal) = x.atan() - def asin(x: DspReal) = x.asin() - def acos(x: DspReal) = x.acos() + def exp(x: DspReal) = x.exp() + def sqrt(x: DspReal) = x.sqrt() + def pow(x: DspReal, n: DspReal) = x.pow(n) + def sin(x: DspReal) = x.sin() + def cos(x: DspReal) = x.cos() + def tan(x: DspReal) = x.tan() + def atan(x: DspReal) = x.atan() + def asin(x: DspReal) = x.asin() + def acos(x: DspReal) = x.acos() def atan2(y: DspReal, x: DspReal) = y.atan2(x) def hypot(x: DspReal, y: DspReal) = x.hypot(y) - def sinh(x: DspReal) = x.sinh() - def cosh(x: DspReal) = x.cosh() - def tanh(x: DspReal) = x.tanh() + def sinh(x: DspReal) = x.sinh() + def cosh(x: DspReal) = x.cosh() + def tanh(x: DspReal) = x.tanh() def asinh(x: DspReal) = x.asinh() def acosh(x: DspReal) = x.acosh() def atanh(x: DspReal) = x.tanh() -} \ No newline at end of file +} diff --git a/src/main/scala/dsptools/numbers/chisel_types/DspComplexTypeClass.scala b/src/main/scala/dsptools/numbers/chisel_types/DspComplexTypeClass.scala index 1449863b..60d153ff 100644 --- a/src/main/scala/dsptools/numbers/chisel_types/DspComplexTypeClass.scala +++ b/src/main/scala/dsptools/numbers/chisel_types/DspComplexTypeClass.scala @@ -9,12 +9,12 @@ import implicits._ import chisel3.util.ShiftRegister import dsptools.DspException -abstract class DspComplexRing[T <: Data:Ring] extends Ring[DspComplex[T]] with hasContext { +abstract class DspComplexRing[T <: Data: Ring] extends Ring[DspComplex[T]] with hasContext { def plus(f: DspComplex[T], g: DspComplex[T]): DspComplex[T] = { DspComplex.wire(f.real + g.real, f.imag + g.imag) } def plusContext(f: DspComplex[T], g: DspComplex[T]): DspComplex[T] = { - DspComplex.wire(f.real context_+ g.real, f.imag context_+ g.imag) + DspComplex.wire(f.real.context_+(g.real), f.imag.context_+(g.imag)) } /** @@ -39,20 +39,20 @@ abstract class DspComplexRing[T <: Data:Ring] extends Ring[DspComplex[T]] with h def timesContext(f: DspComplex[T], g: DspComplex[T]): DspComplex[T] = { if (context.complexUse4Muls) DspComplex.wire( - (f.real context_* g.real) context_- (f.imag context_* g.imag), - (f.real context_* g.imag) context_+ (f.imag context_* g.real) + (f.real.context_*(g.real)).context_-(f.imag.context_*(g.imag)), + (f.real.context_*(g.imag)).context_+(f.imag.context_*(g.real)) ) else { val fRealDly = ShiftRegister(f.real, context.numAddPipes) val gRealDly = ShiftRegister(g.real, context.numAddPipes) val gImagDly = ShiftRegister(g.imag, context.numAddPipes) - val c_p_d = g.real context_+ g.imag - val a_p_b = f.real context_+ f.imag - val b_m_a = f.imag context_- f.real - val ac_p_ad = fRealDly context_* c_p_d - val ad_p_bd = a_p_b context_* gImagDly - val bc_m_ac = b_m_a context_* gRealDly - DspComplex.wire(ac_p_ad context_- ad_p_bd, ac_p_ad context_+ bc_m_ac) + val c_p_d = g.real.context_+(g.imag) + val a_p_b = f.real.context_+(f.imag) + val b_m_a = f.imag.context_-(f.real) + val ac_p_ad = fRealDly.context_*(c_p_d) + val ad_p_bd = a_p_b.context_*(gImagDly) + val bc_m_ac = b_m_a.context_*(gRealDly) + DspComplex.wire(ac_p_ad.context_-(ad_p_bd), ac_p_ad.context_+(bc_m_ac)) } } def one: DspComplex[T] = DspComplex(Ring[T].one, Ring[T].zero) @@ -67,7 +67,7 @@ abstract class DspComplexRing[T <: Data:Ring] extends Ring[DspComplex[T]] with h DspComplex.wire(f.real - g.real, f.imag - g.imag) } def minusContext(f: DspComplex[T], g: DspComplex[T]): DspComplex[T] = { - DspComplex.wire(f.real context_- g.real, f.imag context_- g.imag) + DspComplex.wire(f.real.context_-(g.real), f.imag.context_-(g.imag)) } } @@ -83,11 +83,11 @@ class DspComplexRingFixed extends DspComplexRing[FixedPoint] { override def plusForTimes(l: FixedPoint, r: FixedPoint): FixedPoint = l +& r } -class DspComplexRingData[T <: Data : Ring] extends DspComplexRing[T] { +class DspComplexRingData[T <: Data: Ring] extends DspComplexRing[T] { override protected def plusForTimes(l: T, r: T): T = l + r } -class DspComplexEq[T <: Data:Eq] extends Eq[DspComplex[T]] with hasContext { +class DspComplexEq[T <: Data: Eq] extends Eq[DspComplex[T]] with hasContext { override def eqv(x: DspComplex[T], y: DspComplex[T]): Bool = { Eq[T].eqv(x.real, y.real) && Eq[T].eqv(x.imag, y.imag) } @@ -96,24 +96,25 @@ class DspComplexEq[T <: Data:Eq] extends Eq[DspComplex[T]] with hasContext { } } -class DspComplexBinaryRepresentation[T <: Data:Ring:BinaryRepresentation] extends - BinaryRepresentation[DspComplex[T]] with hasContext { - override def shl(a: DspComplex[T], n: Int): DspComplex[T] = throw DspException("Can't shl on complex") - override def shl(a: DspComplex[T], n: UInt): DspComplex[T] = throw DspException("Can't shl on complex") - override def shr(a: DspComplex[T], n: Int): DspComplex[T] = throw DspException("Can't shr on complex") - override def shr(a: DspComplex[T], n: UInt): DspComplex[T] = throw DspException("Can't shr on complex") - override def div2(a: DspComplex[T], n: Int): DspComplex[T] = DspComplex.wire(a.real.div2(n), a.imag.div2(n)) - override def mul2(a: DspComplex[T], n: Int): DspComplex[T] = DspComplex.wire(a.real.mul2(n), a.imag.mul2(n)) - def clip(a: DspComplex[T], b: DspComplex[T]): DspComplex[T] = throw DspException("Can't clip on complex") - def signBit(a: DspComplex[T]): Bool = throw DspException("Can't get sign bit on complex") - def trimBinary(a: DspComplex[T], n: Option[Int]): DspComplex[T] = +class DspComplexBinaryRepresentation[T <: Data: Ring: BinaryRepresentation] + extends BinaryRepresentation[DspComplex[T]] + with hasContext { + override def shl(a: DspComplex[T], n: Int): DspComplex[T] = throw DspException("Can't shl on complex") + override def shl(a: DspComplex[T], n: UInt): DspComplex[T] = throw DspException("Can't shl on complex") + override def shr(a: DspComplex[T], n: Int): DspComplex[T] = throw DspException("Can't shr on complex") + override def shr(a: DspComplex[T], n: UInt): DspComplex[T] = throw DspException("Can't shr on complex") + override def div2(a: DspComplex[T], n: Int): DspComplex[T] = DspComplex.wire(a.real.div2(n), a.imag.div2(n)) + override def mul2(a: DspComplex[T], n: Int): DspComplex[T] = DspComplex.wire(a.real.mul2(n), a.imag.mul2(n)) + def clip(a: DspComplex[T], b: DspComplex[T]): DspComplex[T] = throw DspException("Can't clip on complex") + def signBit(a: DspComplex[T]): Bool = throw DspException("Can't get sign bit on complex") + def trimBinary(a: DspComplex[T], n: Option[Int]): DspComplex[T] = DspComplex.wire(BinaryRepresentation[T].trimBinary(a.real, n), BinaryRepresentation[T].trimBinary(a.imag, n)) } trait GenericDspComplexImpl { - implicit def DspComplexRingDataImpl[T<: Data:Ring] = new DspComplexRingData[T]() - implicit def DspComplexEq[T <: Data:Eq] = new DspComplexEq[T]() - implicit def DspComplexBinaryRepresentation[T <: Data:Ring:BinaryRepresentation] = + implicit def DspComplexRingDataImpl[T <: Data: Ring] = new DspComplexRingData[T]() + implicit def DspComplexEq[T <: Data: Eq] = new DspComplexEq[T]() + implicit def DspComplexBinaryRepresentation[T <: Data: Ring: BinaryRepresentation] = new DspComplexBinaryRepresentation[T]() } diff --git a/src/main/scala/dsptools/numbers/chisel_types/DspRealTypeClass.scala b/src/main/scala/dsptools/numbers/chisel_types/DspRealTypeClass.scala index c2d48946..6740b78f 100644 --- a/src/main/scala/dsptools/numbers/chisel_types/DspRealTypeClass.scala +++ b/src/main/scala/dsptools/numbers/chisel_types/DspRealTypeClass.scala @@ -2,15 +2,15 @@ package dsptools.numbers +import chisel3.util.ShiftRegister import chisel3.{fromDoubleToLiteral => _, fromIntToBinaryPoint => _, _} -import chisel3.util.{Cat, ShiftRegister} -import dsptools.{DspContext, NoTrim, hasContext} +import dsptools.{hasContext, DspContext, NoTrim} import fixedpoint._ import scala.language.implicitConversions trait DspRealRing extends Any with Ring[DspReal] with hasContext { - def one: DspReal = DspReal(1.0) + def one: DspReal = DspReal(1.0) def zero: DspReal = DspReal(0.0) def plus(f: DspReal, g: DspReal): DspReal = f + g def plusContext(f: DspReal, g: DspReal): DspReal = { @@ -20,9 +20,9 @@ trait DspRealRing extends Any with Ring[DspReal] with hasContext { def minusContext(f: DspReal, g: DspReal): DspReal = { ShiftRegister(f - g, context.numAddPipes) } - def negate(f: DspReal): DspReal = minus(zero, f) + def negate(f: DspReal): DspReal = minus(zero, f) def negateContext(f: DspReal): DspReal = minusContext(zero, f) - def times(f: DspReal, g: DspReal): DspReal = f * g + def times(f: DspReal, g: DspReal): DspReal = f * g def timesContext(f: DspReal, g: DspReal): DspReal = { ShiftRegister(f * g, context.numMulPipes) } @@ -32,11 +32,11 @@ trait DspRealOrder extends Any with Order[DspReal] with hasContext { override def compare(x: DspReal, y: DspReal): ComparisonBundle = { ComparisonHelper(x === y, x < y) } - override def eqv(x: DspReal, y: DspReal): Bool = x === y - override def neqv(x: DspReal, y:DspReal): Bool = x != y - override def lt(x: DspReal, y: DspReal): Bool = x < y + override def eqv(x: DspReal, y: DspReal): Bool = x === y + override def neqv(x: DspReal, y: DspReal): Bool = x != y + override def lt(x: DspReal, y: DspReal): Bool = x < y override def lteqv(x: DspReal, y: DspReal): Bool = x <= y - override def gt(x: DspReal, y: DspReal): Bool = x > y + override def gt(x: DspReal, y: DspReal): Bool = x > y override def gteqv(x: DspReal, y: DspReal): Bool = x >= y // min, max depends on lt, gt & mux } @@ -54,8 +54,8 @@ trait DspRealSigned extends Any with Signed[DspReal] with DspRealRing with hasCo ) } - override def isSignZero(a: DspReal): Bool = a === DspReal(0.0) - override def isSignNegative(a:DspReal): Bool = a < DspReal(0.0) + override def isSignZero(a: DspReal): Bool = a === DspReal(0.0) + override def isSignNegative(a: DspReal): Bool = a < DspReal(0.0) // isSignPositive, isSignNonZero, isSignNonPositive, isSignNonNegative derived from above (!) } @@ -66,26 +66,30 @@ trait DspRealIsReal extends Any with IsReal[DspReal] with DspRealOrder with DspR def context_ceil(a: DspReal): DspReal = { ShiftRegister(a, context.numAddPipes).ceil() } - def floor(a: DspReal): DspReal = a.floor() + def floor(a: DspReal): DspReal = a.floor() def isWhole(a: DspReal): Bool = a === round(a) // Round *half up* -- Different from System Verilog definition! (where half is rounded away from zero) // according to 5.7.2 (http://www.ece.uah.edu/~gaede/cpe526/2012%20System%20Verilog%20Language%20Reference%20Manual.pdf) def round(a: DspReal): DspReal = a.round() def truncate(a: DspReal): DspReal = { - Mux(ShiftRegister(a, context.numAddPipes) < DspReal(0.0), context_ceil(a), floor(ShiftRegister(a, context.numAddPipes))) + Mux( + ShiftRegister(a, context.numAddPipes) < DspReal(0.0), + context_ceil(a), + floor(ShiftRegister(a, context.numAddPipes)) + ) } } trait ConvertableToDspReal extends ConvertableTo[DspReal] with hasContext { - def fromShort(n: Short): DspReal = fromInt(n.toInt) - def fromByte(n: Byte): DspReal = fromInt(n.toInt) - def fromInt(n: Int): DspReal = fromBigInt(BigInt(n)) - def fromFloat(n: Float): DspReal = fromDouble(n.toDouble) - def fromBigDecimal(n: BigDecimal): DspReal = fromDouble(n.doubleValue) - def fromLong(n: Long): DspReal = fromBigInt(BigInt(n)) - def fromType[B](n: B)(implicit c: ConvertableFrom[B]): DspReal = fromDouble(c.toDouble(n)) - def fromBigInt(n: BigInt): DspReal = DspReal(n.doubleValue) - def fromDouble(n: Double): DspReal = DspReal(n) + def fromShort(n: Short): DspReal = fromInt(n.toInt) + def fromByte(n: Byte): DspReal = fromInt(n.toInt) + def fromInt(n: Int): DspReal = fromBigInt(BigInt(n)) + def fromFloat(n: Float): DspReal = fromDouble(n.toDouble) + def fromBigDecimal(n: BigDecimal): DspReal = fromDouble(n.doubleValue) + def fromLong(n: Long): DspReal = fromBigInt(BigInt(n)) + def fromType[B](n: B)(implicit c: ConvertableFrom[B]): DspReal = fromDouble(c.toDouble(n)) + def fromBigInt(n: BigInt): DspReal = DspReal(n.doubleValue) + def fromDouble(n: Double): DspReal = DspReal(n) override def fromDouble(d: Double, a: DspReal): DspReal = fromDouble(d) // Ignores width override def fromDoubleWithFixedWidth(d: Double, a: DspReal): DspReal = fromDouble(d) @@ -119,14 +123,20 @@ trait BinaryRepresentationDspReal extends BinaryRepresentation[DspReal] with has override def div2(a: DspReal, n: Int): DspReal = a / DspReal(math.pow(2, n)) // Used purely for fixed point precision adjustment -- just passes DspReal through def trimBinary(a: DspReal, n: Option[Int]): DspReal = a - } +} -trait DspRealReal extends DspRealRing with DspRealIsReal with ConvertableToDspReal with - ConvertableFromDspReal with BinaryRepresentationDspReal with RealBits[DspReal] with hasContext { - def signBit(a: DspReal): Bool = isSignNegative(a) - override def fromInt(n: Int): DspReal = super[ConvertableToDspReal].fromInt(n) - override def fromBigInt(n: BigInt): DspReal = super[ConvertableToDspReal].fromBigInt(n) - def intPart(a: DspReal): SInt = truncate(a).toSInt() +trait DspRealReal + extends DspRealRing + with DspRealIsReal + with ConvertableToDspReal + with ConvertableFromDspReal + with BinaryRepresentationDspReal + with RealBits[DspReal] + with hasContext { + def signBit(a: DspReal): Bool = isSignNegative(a) + override def fromInt(n: Int): DspReal = super[ConvertableToDspReal].fromInt(n) + override def fromBigInt(n: BigInt): DspReal = super[ConvertableToDspReal].fromBigInt(n) + def intPart(a: DspReal): SInt = truncate(a).toSInt() // WARNING: Beware of overflow(!) def asFixed(a: DspReal, proto: FixedPoint): FixedPoint = { require(proto.binaryPoint.known, "Binary point must be known for DspReal -> FixedPoint") @@ -141,6 +151,6 @@ trait DspRealReal extends DspRealRing with DspRealIsReal with ConvertableToDspRe } } -trait DspRealImpl { +trait DspRealImpl { implicit object DspRealRealImpl extends DspRealReal } diff --git a/src/main/scala/dsptools/numbers/chisel_types/FixedPointTypeClass.scala b/src/main/scala/dsptools/numbers/chisel_types/FixedPointTypeClass.scala index 726ca6b5..128dc6d1 100644 --- a/src/main/scala/dsptools/numbers/chisel_types/FixedPointTypeClass.scala +++ b/src/main/scala/dsptools/numbers/chisel_types/FixedPointTypeClass.scala @@ -16,14 +16,14 @@ import scala.language.implicitConversions */ trait FixedPointRing extends Any with Ring[FixedPoint] with hasContext { def zero: FixedPoint = 0.0.F(0.BP) - def one: FixedPoint= 1.0.F(0.BP) + def one: FixedPoint = 1.0.F(0.BP) def plus(f: FixedPoint, g: FixedPoint): FixedPoint = f + g def plusContext(f: FixedPoint, g: FixedPoint): FixedPoint = { // TODO: Saturating mux should be outside of ShiftRegister val sum = context.overflowType match { case Grow => f +& g case Wrap => f +% g - case _ => throw DspException("Saturating add hasn't been implemented") + case _ => throw DspException("Saturating add hasn't been implemented") } ShiftRegister(sum, context.numAddPipes) } @@ -32,11 +32,11 @@ trait FixedPointRing extends Any with Ring[FixedPoint] with hasContext { val diff = context.overflowType match { case Grow => f -& g case Wrap => f -% g - case _ => throw DspException("Saturating subtractor hasn't been implemented") + case _ => throw DspException("Saturating subtractor hasn't been implemented") } ShiftRegister(diff, context.numAddPipes) } - def negate(f: FixedPoint): FixedPoint = -f + def negate(f: FixedPoint): FixedPoint = -f def negateContext(f: FixedPoint): FixedPoint = minus(zero, f) def times(f: FixedPoint, g: FixedPoint): FixedPoint = f * g @@ -48,11 +48,11 @@ trait FixedPointOrder extends Any with Order[FixedPoint] with hasContext { override def compare(x: FixedPoint, y: FixedPoint): ComparisonBundle = { ComparisonHelper(x === y, x < y) } - override def eqv(x: FixedPoint, y: FixedPoint): Bool = x === y - override def neqv(x: FixedPoint, y:FixedPoint): Bool = x =/= y - override def lt(x: FixedPoint, y: FixedPoint): Bool = x < y + override def eqv(x: FixedPoint, y: FixedPoint): Bool = x === y + override def neqv(x: FixedPoint, y: FixedPoint): Bool = x =/= y + override def lt(x: FixedPoint, y: FixedPoint): Bool = x < y override def lteqv(x: FixedPoint, y: FixedPoint): Bool = x <= y - override def gt(x: FixedPoint, y: FixedPoint): Bool = x > y + override def gt(x: FixedPoint, y: FixedPoint): Bool = x > y override def gteqv(x: FixedPoint, y: FixedPoint): Bool = x >= y // min, max depends on lt, gt & mux } @@ -65,11 +65,12 @@ trait FixedPointSigned extends Any with Signed[FixedPoint] with hasContext { trait FixedPointIsReal extends Any with IsReal[FixedPoint] with FixedPointOrder with FixedPointSigned with hasContext { // Chop off fractional bits --> round to negative infinity - def floor(a: FixedPoint): FixedPoint = a.setBinaryPoint(0) + def floor(a: FixedPoint): FixedPoint = a.setBinaryPoint(0) def isWhole(a: FixedPoint): Bool = a === floor(a) // Truncate = round towards zero (integer part without fractional bits) def truncate(a: FixedPoint): FixedPoint = { - shadow.Mux(isSignNegative(ShiftRegister(a, context.numAddPipes)), + shadow.Mux( + isSignNegative(ShiftRegister(a, context.numAddPipes)), ceil(a), floor(ShiftRegister(a, context.numAddPipes)) ) @@ -78,14 +79,14 @@ trait FixedPointIsReal extends Any with IsReal[FixedPoint] with FixedPointOrder } trait ConvertableToFixedPoint extends ConvertableTo[FixedPoint] with hasContext { - def fromShort(n: Short): FixedPoint = fromInt(n.toInt) - def fromByte(n: Byte): FixedPoint = fromInt(n.toInt) - def fromInt(n: Int): FixedPoint = fromBigInt(BigInt(n)) - def fromFloat(n: Float): FixedPoint = fromDouble(n.toDouble) + def fromShort(n: Short): FixedPoint = fromInt(n.toInt) + def fromByte(n: Byte): FixedPoint = fromInt(n.toInt) + def fromInt(n: Int): FixedPoint = fromBigInt(BigInt(n)) + def fromFloat(n: Float): FixedPoint = fromDouble(n.toDouble) def fromBigDecimal(n: BigDecimal): FixedPoint = fromDouble(n.doubleValue) - def fromLong(n: Long): FixedPoint = fromBigInt(BigInt(n)) - def fromType[B](n: B)(implicit c: ConvertableFrom[B]): FixedPoint = fromDouble(c.toDouble(n)) - def fromBigInt(n: BigInt): FixedPoint = n.doubleValue.F(0.BP) + def fromLong(n: Long): FixedPoint = fromBigInt(BigInt(n)) + def fromType[B](n: B)(implicit c: ConvertableFrom[B]): FixedPoint = fromDouble(c.toDouble(n)) + def fromBigInt(n: BigInt): FixedPoint = n.doubleValue.F(0.BP) // If no binary point is specified, use the default one provided by DspContext // TODO: Should you instead be specifying a max width so you can get the most resolution for a given width? def fromDouble(n: Double): FixedPoint = n.F(context.binaryPoint.getOrElse(0).BP) @@ -106,15 +107,15 @@ trait ConvertableFromFixedPoint extends ChiselConvertableFrom[FixedPoint] with h // intPart depends on truncate // asReal depends on shifting fractional bits up override def asFixed(a: FixedPoint): FixedPoint = a - def asFixed(a: FixedPoint, proto: FixedPoint): FixedPoint = asFixed(a) + def asFixed(a: FixedPoint, proto: FixedPoint): FixedPoint = asFixed(a) } trait BinaryRepresentationFixedPoint extends BinaryRepresentation[FixedPoint] with hasContext { - def shl(a: FixedPoint, n: Int): FixedPoint = a << n + def shl(a: FixedPoint, n: Int): FixedPoint = a << n def shl(a: FixedPoint, n: UInt): FixedPoint = a << n // Note: This rounds to negative infinity (smallest abs. value for negative #'s is -LSB) - def shr(a: FixedPoint, n: Int): FixedPoint = a >> n + def shr(a: FixedPoint, n: Int): FixedPoint = a >> n def shr(a: FixedPoint, n: UInt): FixedPoint = a >> n // mul2 consistent with shl @@ -139,10 +140,16 @@ trait BinaryRepresentationFixedPoint extends BinaryRepresentation[FixedPoint] wi } // trimBinary below for access to ring ops - } +} -trait FixedPointReal extends FixedPointRing with FixedPointIsReal with ConvertableToFixedPoint with - ConvertableFromFixedPoint with BinaryRepresentationFixedPoint with RealBits[FixedPoint] with hasContext { +trait FixedPointReal + extends FixedPointRing + with FixedPointIsReal + with ConvertableToFixedPoint + with ConvertableFromFixedPoint + with BinaryRepresentationFixedPoint + with RealBits[FixedPoint] + with hasContext { def clip(a: FixedPoint, b: FixedPoint): FixedPoint = ??? @@ -150,60 +157,83 @@ trait FixedPointReal extends FixedPointRing with FixedPointIsReal with Convertab // TODO: Support other modes? n match { case None => a - case Some(b) => context.trimType match { - case NoTrim => a - case RoundDown => a.setBinaryPoint(b) - case RoundUp => { - val addAmt = math.pow(2, -b).F(b.BP) // shr(1.0.F(b.BP),b) - shadow.Mux((a === a.setBinaryPoint(b)), a.setBinaryPoint(b), plus(a.setBinaryPoint(b), addAmt)) - } - case RoundTowardsZero => { - val addAmt = math.pow(2, -b).F(b.BP) // shr(1.0.F(b.BP),b) - val valueForNegativeNum = shadow.Mux((a === a.setBinaryPoint(b)), a.setBinaryPoint(b), plus(a.setBinaryPoint(b), addAmt)) - shadow.Mux(isSignNegative(a), valueForNegativeNum, a.setBinaryPoint(b)) - } - case RoundTowardsInfinity => { - val addAmt = math.pow(2, -b).F(b.BP) // shr(1.0.F(b.BP),b) - val valueForPositiveNum = shadow.Mux((a === a.setBinaryPoint(b)), a.setBinaryPoint(b), plus(a.setBinaryPoint(b), addAmt)) - shadow.Mux(isSignNegative(a), a.setBinaryPoint(b), valueForPositiveNum) - } - case RoundHalfDown => { - val addAmt1 = math.pow(2, -b).F(b.BP) // shr(1.0.F(b.BP),b) - val addAmt2 = math.pow(2, -(b+1)).F((b+1).BP) // shr(1.0.F((b+1).BP),(b+1)) - shadow.Mux((a > plus(a.setBinaryPoint(b), addAmt2)), plus(a.setBinaryPoint(b), addAmt1), a.setBinaryPoint(b)) - } - case RoundHalfUp => { - val roundBp = b + 1 - val addAmt = math.pow(2, -roundBp).F(roundBp.BP) - plus(a, addAmt).setBinaryPoint(b) + case Some(b) => + context.trimType match { + case NoTrim => a + case RoundDown => a.setBinaryPoint(b) + case RoundUp => { + val addAmt = math.pow(2, -b).F(b.BP) // shr(1.0.F(b.BP),b) + shadow.Mux((a === a.setBinaryPoint(b)), a.setBinaryPoint(b), plus(a.setBinaryPoint(b), addAmt)) + } + case RoundTowardsZero => { + val addAmt = math.pow(2, -b).F(b.BP) // shr(1.0.F(b.BP),b) + val valueForNegativeNum = + shadow.Mux((a === a.setBinaryPoint(b)), a.setBinaryPoint(b), plus(a.setBinaryPoint(b), addAmt)) + shadow.Mux(isSignNegative(a), valueForNegativeNum, a.setBinaryPoint(b)) + } + case RoundTowardsInfinity => { + val addAmt = math.pow(2, -b).F(b.BP) // shr(1.0.F(b.BP),b) + val valueForPositiveNum = + shadow.Mux((a === a.setBinaryPoint(b)), a.setBinaryPoint(b), plus(a.setBinaryPoint(b), addAmt)) + shadow.Mux(isSignNegative(a), a.setBinaryPoint(b), valueForPositiveNum) + } + case RoundHalfDown => { + val addAmt1 = math.pow(2, -b).F(b.BP) // shr(1.0.F(b.BP),b) + val addAmt2 = math.pow(2, -(b + 1)).F((b + 1).BP) // shr(1.0.F((b+1).BP),(b+1)) + shadow.Mux( + (a > plus(a.setBinaryPoint(b), addAmt2)), + plus(a.setBinaryPoint(b), addAmt1), + a.setBinaryPoint(b) + ) + } + case RoundHalfUp => { + val roundBp = b + 1 + val addAmt = math.pow(2, -roundBp).F(roundBp.BP) + plus(a, addAmt).setBinaryPoint(b) + } + case RoundHalfTowardsZero => { + val addAmt1 = math.pow(2, -b).F(b.BP) // shr(1.0.F(b.BP),b) + val addAmt2 = math.pow(2, -(b + 1)).F((b + 1).BP) // shr(1.0.F((b+1).BP),(b+1)) + val valueForPositiveNum = shadow.Mux( + (a > plus(a.setBinaryPoint(b), addAmt2)), + plus(a.setBinaryPoint(b), addAmt1), + a.setBinaryPoint(b) + ) + shadow.Mux(isSignNegative(a), plus(a, addAmt2).setBinaryPoint(b), valueForPositiveNum) + } + case RoundHalfTowardsInfinity => { + val roundBp = b + 1 + val addAmt = math.pow(2, -roundBp).F(roundBp.BP) + shadow.Mux( + isSignNegative(a) && (a === a.setBinaryPoint(roundBp)), + a.setBinaryPoint(b), + plus(a, addAmt).setBinaryPoint(b) + ) + } + case RoundHalfToEven => { + require(b > 0, "Binary point of input fixed point number must be larger than zero when trimming") + val roundBp = b + 1 + val checkIfEvenBp = b - 1 + val addAmt = math.pow(2, -roundBp).F(roundBp.BP) + shadow.Mux( + (a.setBinaryPoint(checkIfEvenBp) === a.setBinaryPoint(b)) && (a === a.setBinaryPoint(roundBp)), + a.setBinaryPoint(b), + plus(a, addAmt).setBinaryPoint(b) + ) + } + case RoundHalfToOdd => { + require(b > 0, "Binary point of input fixed point number must be larger than zero when trimming") + val roundBp = b + 1 + val checkIfOddBp = b - 1 + val addAmt = math.pow(2, -roundBp).F(roundBp.BP) + shadow.Mux( + (a.setBinaryPoint(checkIfOddBp) =/= a.setBinaryPoint(b)) && (a === a.setBinaryPoint(roundBp)), + a.setBinaryPoint(b), + plus(a, addAmt).setBinaryPoint(b) + ) + } + case _ => throw DspException("Desired trim type not implemented!") } - case RoundHalfTowardsZero => { - val addAmt1 = math.pow(2, -b).F(b.BP) // shr(1.0.F(b.BP),b) - val addAmt2 = math.pow(2, -(b+1)).F((b+1).BP) // shr(1.0.F((b+1).BP),(b+1)) - val valueForPositiveNum = shadow.Mux((a > plus(a.setBinaryPoint(b), addAmt2)), plus(a.setBinaryPoint(b), addAmt1), a.setBinaryPoint(b)) - shadow.Mux(isSignNegative(a), plus(a, addAmt2).setBinaryPoint(b), valueForPositiveNum) - } - case RoundHalfTowardsInfinity => { - val roundBp = b + 1 - val addAmt = math.pow(2, -roundBp).F(roundBp.BP) - shadow.Mux(isSignNegative(a) && (a === a.setBinaryPoint(roundBp)), a.setBinaryPoint(b), plus(a, addAmt).setBinaryPoint(b)) - } - case RoundHalfToEven => { - require(b > 0, "Binary point of input fixed point number must be larger than zero when trimming") - val roundBp = b + 1 - val checkIfEvenBp = b - 1 - val addAmt = math.pow(2, -roundBp).F(roundBp.BP) - shadow.Mux((a.setBinaryPoint(checkIfEvenBp) === a.setBinaryPoint(b)) && (a === a.setBinaryPoint(roundBp)), a.setBinaryPoint(b), plus(a, addAmt).setBinaryPoint(b)) - } - case RoundHalfToOdd => { - require(b > 0, "Binary point of input fixed point number must be larger than zero when trimming") - val roundBp = b + 1 - val checkIfOddBp = b - 1 - val addAmt = math.pow(2, -roundBp).F(roundBp.BP) - shadow.Mux((a.setBinaryPoint(checkIfOddBp) =/= a.setBinaryPoint(b)) && (a === a.setBinaryPoint(roundBp)), a.setBinaryPoint(b), plus(a, addAmt).setBinaryPoint(b)) - } - case _ => throw DspException("Desired trim type not implemented!") - } } } @@ -213,7 +243,7 @@ trait FixedPointReal extends FixedPointRing with FixedPointIsReal with Convertab val outTemp = ShiftRegister(f * g, context.numMulPipes) val newBP = (f.binaryPoint, g.binaryPoint) match { case (KnownBinaryPoint(i), KnownBinaryPoint(j)) => Some(i.max(j) + context.binaryPointGrowth) - case (_, _) => None + case (_, _) => None } trimBinary(outTemp, newBP) } @@ -222,8 +252,8 @@ trait FixedPointReal extends FixedPointRing with FixedPointIsReal with Convertab ComparisonHelper(a === zero, a < zero) } override def isSignZero(a: FixedPoint): Bool = a === zero - override def isSignNegative(a:FixedPoint): Bool = { - if (a.widthKnown) a(a.getWidth-1) + override def isSignNegative(a: FixedPoint): Bool = { + if (a.widthKnown) a(a.getWidth - 1) else a < zero } @@ -232,7 +262,8 @@ trait FixedPointReal extends FixedPointRing with FixedPointIsReal with Convertab shadow.Mux( isWhole(ShiftRegister(a, context.numAddPipes)), floor(ShiftRegister(a, context.numAddPipes)), - plusContext(floor(a), one)) + plusContext(floor(a), one) + ) } def context_ceil(a: FixedPoint): FixedPoint = ceil(a) @@ -242,7 +273,7 @@ trait FixedPointReal extends FixedPointRing with FixedPointIsReal with Convertab def signBit(a: FixedPoint): Bool = isSignNegative(a) // fromFixedPoint also included in Ring - override def fromInt(n: Int): FixedPoint = super[ConvertableToFixedPoint].fromInt(n) + override def fromInt(n: Int): FixedPoint = super[ConvertableToFixedPoint].fromInt(n) override def fromBigInt(n: BigInt): FixedPoint = super[ConvertableToFixedPoint].fromBigInt(n) // Overflow only on most negative def abs(a: FixedPoint): FixedPoint = { @@ -252,7 +283,8 @@ trait FixedPointReal extends FixedPointRing with FixedPointIsReal with Convertab shadow.Mux( isSignNegative(ShiftRegister(a, context.numAddPipes)), super[FixedPointRing].minusContext(zero, a), - ShiftRegister(a, context.numAddPipes)) + ShiftRegister(a, context.numAddPipes) + ) } def intPart(a: FixedPoint): SInt = truncate(a).asSInt @@ -262,7 +294,7 @@ trait FixedPointReal extends FixedPointRing with FixedPointIsReal with Convertab require(a.binaryPoint.known, "Binary point must be known for asReal") val n = a.binaryPoint.get val normalizedInt = a << n - DspReal(floor(normalizedInt).asSInt)/DspReal((1 << n).toDouble) + DspReal(floor(normalizedInt).asSInt) / DspReal((1 << n).toDouble) } } diff --git a/src/main/scala/dsptools/numbers/chisel_types/SIntTypeClass.scala b/src/main/scala/dsptools/numbers/chisel_types/SIntTypeClass.scala index 1db6ee7f..4d6c3edc 100644 --- a/src/main/scala/dsptools/numbers/chisel_types/SIntTypeClass.scala +++ b/src/main/scala/dsptools/numbers/chisel_types/SIntTypeClass.scala @@ -4,7 +4,7 @@ package dsptools.numbers import chisel3.{fromDoubleToLiteral => _, fromIntToBinaryPoint => _, _} import chisel3.util.{Cat, ShiftRegister} -import dsptools.{DspContext, DspException, Grow, NoTrim, Saturate, Wrap, hasContext} +import dsptools.{hasContext, DspContext, DspException, Grow, NoTrim, Saturate, Wrap} import fixedpoint._ import scala.language.implicitConversions @@ -14,14 +14,14 @@ import scala.language.implicitConversions */ trait SIntRing extends Any with Ring[SInt] with hasContext { def zero: SInt = 0.S - def one: SInt = 1.S + def one: SInt = 1.S def plus(f: SInt, g: SInt): SInt = f + g def plusContext(f: SInt, g: SInt): SInt = { // TODO: Saturating mux should be outside of ShiftRegister val sum = context.overflowType match { case Grow => f +& g case Wrap => f +% g - case _ => throw DspException("Saturating add hasn't been implemented") + case _ => throw DspException("Saturating add hasn't been implemented") } ShiftRegister(sum, context.numAddPipes) } @@ -30,7 +30,7 @@ trait SIntRing extends Any with Ring[SInt] with hasContext { val diff = context.overflowType match { case Grow => f -& g case Wrap => f -% g - case _ => throw DspException("Saturating subtractor hasn't been implemented") + case _ => throw DspException("Saturating subtractor hasn't been implemented") } ShiftRegister(diff, context.numAddPipes) } @@ -45,18 +45,18 @@ trait SIntRing extends Any with Ring[SInt] with hasContext { def timesContext(f: SInt, g: SInt): SInt = { // TODO: Overflow via ranging in FIRRTL? ShiftRegister(f * g, context.numMulPipes) - } + } } trait SIntOrder extends Any with Order[SInt] with hasContext { override def compare(x: SInt, y: SInt): ComparisonBundle = { ComparisonHelper(x === y, x < y) } - override def eqv(x: SInt, y: SInt): Bool = x === y - override def neqv(x: SInt, y:SInt): Bool = x =/= y - override def lt(x: SInt, y: SInt): Bool = x < y + override def eqv(x: SInt, y: SInt): Bool = x === y + override def neqv(x: SInt, y: SInt): Bool = x =/= y + override def lt(x: SInt, y: SInt): Bool = x < y override def lteqv(x: SInt, y: SInt): Bool = x <= y - override def gt(x: SInt, y: SInt): Bool = x > y + override def gt(x: SInt, y: SInt): Bool = x > y override def gteqv(x: SInt, y: SInt): Bool = x >= y // min, max depends on lt, gt & mux } @@ -66,8 +66,8 @@ trait SIntSigned extends Any with Signed[SInt] with hasContext { ComparisonHelper(a === 0.S, a < 0.S) } override def isSignZero(a: SInt): Bool = a === 0.S - override def isSignNegative(a:SInt): Bool = { - if (a.widthKnown) a(a.getWidth-1) + override def isSignNegative(a: SInt): Bool = { + if (a.widthKnown) a(a.getWidth - 1) else a < 0.S } // isSignPositive, isSignNonZero, isSignNonPositive, isSignNonNegative derived from above (!) @@ -88,24 +88,24 @@ trait SIntIsReal extends Any with IsIntegral[SInt] with SIntOrder with SIntSigne trait ConvertableToSInt extends ConvertableTo[SInt] with hasContext { // Note: Double converted to Int via round first! - def fromShort(n: Short): SInt = fromInt(n.toInt) - def fromByte(n: Byte): SInt = fromInt(n.toInt) - def fromInt(n: Int): SInt = fromBigInt(BigInt(n)) - def fromFloat(n: Float): SInt = fromDouble(n.toDouble) + def fromShort(n: Short): SInt = fromInt(n.toInt) + def fromByte(n: Byte): SInt = fromInt(n.toInt) + def fromInt(n: Int): SInt = fromBigInt(BigInt(n)) + def fromFloat(n: Float): SInt = fromDouble(n.toDouble) def fromBigDecimal(n: BigDecimal): SInt = fromDouble(n.doubleValue) - def fromLong(n: Long): SInt = fromBigInt(BigInt(n)) - def fromType[B](n: B)(implicit c: ConvertableFrom[B]): SInt = fromBigInt(c.toBigInt(n)) - def fromBigInt(n: BigInt): SInt = n.S - def fromDouble(n: Double): SInt = n.round.toInt.S + def fromLong(n: Long): SInt = fromBigInt(BigInt(n)) + def fromType[B](n: B)(implicit c: ConvertableFrom[B]): SInt = fromBigInt(c.toBigInt(n)) + def fromBigInt(n: BigInt): SInt = n.S + def fromDouble(n: Double): SInt = n.round.toInt.S // Second argument needed for fixed pt binary point (unused here) override def fromDouble(d: Double, a: SInt): SInt = fromDouble(d) override def fromDoubleWithFixedWidth(d: Double, a: SInt): SInt = { require(a.widthKnown, "SInt width not known!") - val intVal = d.round.toInt + val intVal = d.round.toInt val intBits = BigInt(intVal).bitLength + 1 require(intBits <= a.getWidth, "Lit can't fit in prototype SInt bitwidth") intVal.asSInt(a.getWidth.W) - } + } } trait ConvertableFromSInt extends ChiselConvertableFrom[SInt] with hasContext { @@ -113,30 +113,36 @@ trait ConvertableFromSInt extends ChiselConvertableFrom[SInt] with hasContext { // Converts to FixedPoint with 0 fractional bits (Note: proto only used for real) override def asFixed(a: SInt): FixedPoint = a.asFixedPoint(0.BP) - def asFixed(a: SInt, proto: FixedPoint): FixedPoint = asFixed(a) + def asFixed(a: SInt, proto: FixedPoint): FixedPoint = asFixed(a) // Converts to (signed) DspReal def asReal(a: SInt): DspReal = DspReal(a) } trait BinaryRepresentationSInt extends BinaryRepresentation[SInt] with hasContext { def clip(a: SInt, b: SInt): SInt = ??? - def shl(a: SInt, n: Int): SInt = a << n - def shl(a: SInt, n: UInt): SInt = a << n + def shl(a: SInt, n: Int): SInt = a << n + def shl(a: SInt, n: UInt): SInt = a << n // Note: This rounds to negative infinity (smallest abs. value for negative #'s is -1) - def shr(a: SInt, n: Int): SInt = a >> n + def shr(a: SInt, n: Int): SInt = a >> n def shr(a: SInt, n: UInt): SInt = a >> n // Doesn't affect anything except FixedPoint (no such thing as negative n) - override def trimBinary(a: SInt, n: Int): SInt = a - def trimBinary(a: SInt, n: Option[Int]): SInt = a + override def trimBinary(a: SInt, n: Int): SInt = a + def trimBinary(a: SInt, n: Option[Int]): SInt = a // mul2 consistent with shl // signBit relies on Signed, div2 relies on ChiselConvertableFrom - } +} -trait SIntInteger extends SIntRing with SIntIsReal with ConvertableToSInt with - ConvertableFromSInt with BinaryRepresentationSInt with IntegerBits[SInt] with hasContext { +trait SIntInteger + extends SIntRing + with SIntIsReal + with ConvertableToSInt + with ConvertableFromSInt + with BinaryRepresentationSInt + with IntegerBits[SInt] + with hasContext { def signBit(a: SInt): Bool = isSignNegative(a) // fromSInt also included in Ring - override def fromInt(n: Int): SInt = super[ConvertableToSInt].fromInt(n) + override def fromInt(n: Int): SInt = super[ConvertableToSInt].fromInt(n) override def fromBigInt(n: BigInt): SInt = super[ConvertableToSInt].fromBigInt(n) // Overflow only on most negative def abs(a: SInt): SInt = Mux(isSignNegative(a), super[SIntRing].minus(0.S, a), a) diff --git a/src/main/scala/dsptools/numbers/chisel_types/UIntTypeClass.scala b/src/main/scala/dsptools/numbers/chisel_types/UIntTypeClass.scala index ebfebbdc..b8c0f99e 100644 --- a/src/main/scala/dsptools/numbers/chisel_types/UIntTypeClass.scala +++ b/src/main/scala/dsptools/numbers/chisel_types/UIntTypeClass.scala @@ -4,7 +4,7 @@ package dsptools.numbers import chisel3.{fromDoubleToLiteral => _, fromIntToBinaryPoint => _, _} import chisel3.util.{Cat, ShiftRegister} -import dsptools.{DspContext, DspException, Grow, Saturate, Wrap, hasContext} +import dsptools.{hasContext, DspContext, DspException, Grow, Saturate, Wrap} import fixedpoint._ import scala.language.implicitConversions @@ -14,14 +14,14 @@ import scala.language.implicitConversions */ trait UIntRing extends Any with Ring[UInt] with hasContext { def zero: UInt = 0.U - def one: UInt = 1.U + def one: UInt = 1.U def plus(f: UInt, g: UInt): UInt = f + g def plusContext(f: UInt, g: UInt): UInt = { // TODO: Saturating mux should be outside of ShiftRegister val sum = context.overflowType match { case Grow => f +& g case Wrap => f +% g - case _ => throw DspException("Saturating add hasn't been implemented") + case _ => throw DspException("Saturating add hasn't been implemented") } ShiftRegister(sum, context.numAddPipes) } @@ -30,28 +30,28 @@ trait UIntRing extends Any with Ring[UInt] with hasContext { val diff = context.overflowType match { case Grow => throw DspException("OverflowType Grow is not supported for UInt subtraction") case Wrap => f -% g - case _ => throw DspException("Saturating subtractor hasn't been implemented") + case _ => throw DspException("Saturating subtractor hasn't been implemented") } ShiftRegister(diff.asUInt, context.numAddPipes) } - def negate(f: UInt): UInt = -f + def negate(f: UInt): UInt = -f def negateContext(f: UInt): UInt = throw DspException("Can't negate UInt and get UInt") - def times(f: UInt, g: UInt): UInt = f * g + def times(f: UInt, g: UInt): UInt = f * g def timesContext(f: UInt, g: UInt): UInt = { // TODO: Overflow via ranging in FIRRTL? ShiftRegister(f * g, context.numMulPipes) - } + } } trait UIntOrder extends Any with Order[UInt] with hasContext { override def compare(x: UInt, y: UInt): ComparisonBundle = { ComparisonHelper(x === y, x < y) } - override def eqv(x: UInt, y: UInt): Bool = x === y - override def neqv(x: UInt, y:UInt): Bool = x =/= y - override def lt(x: UInt, y: UInt): Bool = x < y + override def eqv(x: UInt, y: UInt): Bool = x === y + override def neqv(x: UInt, y: UInt): Bool = x =/= y + override def lt(x: UInt, y: UInt): Bool = x < y override def lteqv(x: UInt, y: UInt): Bool = x <= y - override def gt(x: UInt, y: UInt): Bool = x > y + override def gt(x: UInt, y: UInt): Bool = x > y override def gteqv(x: UInt, y: UInt): Bool = x >= y // min, max depends on lt, gt } @@ -60,18 +60,18 @@ trait UIntSigned extends Any with Signed[UInt] with hasContext { def signum(a: UInt): ComparisonBundle = { ComparisonHelper(a === 0.U, a < 0.U) } - def abs(a: UInt): UInt = a // UInts are unsigned! - def context_abs(a: UInt): UInt = a // UInts are unsigned! - override def isSignZero(a: UInt): Bool = a === 0.U + def abs(a: UInt): UInt = a // UInts are unsigned! + def context_abs(a: UInt): UInt = a // UInts are unsigned! + override def isSignZero(a: UInt): Bool = a === 0.U override def isSignPositive(a: UInt): Bool = !isSignZero(a) override def isSignNegative(a: UInt): Bool = false.B // isSignNonZero, isSignNonPositive, isSignNonNegative derived from above (!) } trait UIntIsReal extends Any with IsIntegral[UInt] with UIntOrder with UIntSigned with hasContext { - // In IsIntegral: ceil, floor, round, truncate (from IsReal) already defined as itself; + // In IsIntegral: ceil, floor, round, truncate (from IsReal) already defined as itself; // isWhole always true - + // Unsure what happens if you have a zero-width wire def isOdd(a: UInt): Bool = a(0) // isEven derived from isOdd @@ -82,27 +82,27 @@ trait UIntIsReal extends Any with IsIntegral[UInt] with UIntOrder with UIntSigne trait ConvertableToUInt extends ConvertableTo[UInt] with hasContext { // Note: Double converted to Int via round first! - def fromShort(n: Short): UInt = fromInt(n.toInt) - def fromByte(n: Byte): UInt = fromInt(n.toInt) - def fromInt(n: Int): UInt = fromBigInt(BigInt(n)) - def fromFloat(n: Float): UInt = fromDouble(n.toDouble) + def fromShort(n: Short): UInt = fromInt(n.toInt) + def fromByte(n: Byte): UInt = fromInt(n.toInt) + def fromInt(n: Int): UInt = fromBigInt(BigInt(n)) + def fromFloat(n: Float): UInt = fromDouble(n.toDouble) def fromBigDecimal(n: BigDecimal): UInt = fromDouble(n.doubleValue) - def fromLong(n: Long): UInt = fromBigInt(BigInt(n)) - def fromType[B](n: B)(implicit c: ConvertableFrom[B]): UInt = fromBigInt(c.toBigInt(n)) + def fromLong(n: Long): UInt = fromBigInt(BigInt(n)) + def fromType[B](n: B)(implicit c: ConvertableFrom[B]): UInt = fromBigInt(c.toBigInt(n)) def fromBigInt(n: BigInt): UInt = { require(n >= 0, "Literal to UInt needs to be >= 0") n.U } def fromDouble(n: Double): UInt = { require(n >= 0, "Double literal to UInt needs to be >= 0") - n.round.toInt.U + n.round.toInt.U } // Second argument needed for fixed pt binary point (unused here) override def fromDouble(d: Double, a: UInt): UInt = fromDouble(d) override def fromDoubleWithFixedWidth(d: Double, a: UInt): UInt = { require(a.widthKnown, "UInt width not known!") require(d >= 0, "Double literal to UInt needs to be >= 0") - val intVal = d.round.toInt + val intVal = d.round.toInt val intBits = BigInt(intVal).bitLength require(intBits <= a.getWidth, "Lit can't fit in prototype UInt bitwidth") intVal.asUInt(a.getWidth.W) @@ -115,29 +115,35 @@ trait ConvertableFromUInt extends ChiselConvertableFrom[UInt] with hasContext { // Converts to FixedPoint with 0 fractional bits (second arg only used for DspReal) override def asFixed(a: UInt): FixedPoint = intPart(a).asFixedPoint(0.BP) - def asFixed(a: UInt, proto: FixedPoint): FixedPoint = asFixed(a) + def asFixed(a: UInt, proto: FixedPoint): FixedPoint = asFixed(a) // Converts to (signed) DspReal def asReal(a: UInt): DspReal = DspReal(intPart(a)) } trait BinaryRepresentationUInt extends BinaryRepresentation[UInt] with hasContext { - def shl(a: UInt, n: Int): UInt = a << n - def shl(a: UInt, n: UInt): UInt = a << n - def shr(a: UInt, n: Int): UInt = a >> n - def shr(a: UInt, n: UInt): UInt = a >> n + def shl(a: UInt, n: Int): UInt = a << n + def shl(a: UInt, n: UInt): UInt = a << n + def shr(a: UInt, n: Int): UInt = a >> n + def shr(a: UInt, n: UInt): UInt = a >> n def clip(a: UInt, n: UInt): UInt = ??? // Ignores negative trims (n not used for anything except Fixed) - override def trimBinary(a: UInt, n: Int): UInt = a - def trimBinary(a: UInt, n: Option[Int]): UInt = a + override def trimBinary(a: UInt, n: Int): UInt = a + def trimBinary(a: UInt, n: Option[Int]): UInt = a // signBit relies on Signed // mul2, div2 consistent with shl, shr - } +} -trait UIntInteger extends UIntRing with UIntIsReal with ConvertableToUInt with - ConvertableFromUInt with BinaryRepresentationUInt with IntegerBits[UInt] with hasContext { +trait UIntInteger + extends UIntRing + with UIntIsReal + with ConvertableToUInt + with ConvertableFromUInt + with BinaryRepresentationUInt + with IntegerBits[UInt] + with hasContext { def signBit(a: UInt): Bool = isSignNegative(a) // fromUInt also included in Ring - override def fromInt(n: Int): UInt = super[ConvertableToUInt].fromInt(n) + override def fromInt(n: Int): UInt = super[ConvertableToUInt].fromInt(n) override def fromBigInt(n: BigInt): UInt = super[ConvertableToUInt].fromBigInt(n) } diff --git a/src/main/scala/dsptools/numbers/convertible_types/ConvertableTo.scala b/src/main/scala/dsptools/numbers/convertible_types/ConvertableTo.scala index e1246b03..a719438d 100644 --- a/src/main/scala/dsptools/numbers/convertible_types/ConvertableTo.scala +++ b/src/main/scala/dsptools/numbers/convertible_types/ConvertableTo.scala @@ -9,6 +9,6 @@ object ConvertableTo { } trait ConvertableTo[A <: Data] extends Any with spire.math.ConvertableTo[A] { - def fromDouble(d: Double, a: A): A + def fromDouble(d: Double, a: A): A def fromDoubleWithFixedWidth(d: Double, a: A): A -} \ No newline at end of file +} diff --git a/src/main/scala/dsptools/numbers/implicits/AllOps.scala b/src/main/scala/dsptools/numbers/implicits/AllOps.scala index 5f1487ad..46e8f640 100644 --- a/src/main/scala/dsptools/numbers/implicits/AllOps.scala +++ b/src/main/scala/dsptools/numbers/implicits/AllOps.scala @@ -22,103 +22,103 @@ final class EqOps[A <: Data](lhs: A)(implicit ev: Eq[A]) { } final class PartialOrderOps[A <: Data](lhs: A)(implicit ev: PartialOrder[A]) { - def >(rhs: A): Bool = macro Ops.binop[A, Bool] + def >(rhs: A): Bool = macro Ops.binop[A, Bool] def >=(rhs: A): Bool = macro Ops.binop[A, Bool] - def <(rhs: A): Bool = macro Ops.binop[A, Bool] + def <(rhs: A): Bool = macro Ops.binop[A, Bool] def <=(rhs: A): Bool = macro Ops.binop[A, Bool] def partialCompare(rhs: A): Double = macro Ops.binop[A, Double] - def tryCompare(rhs: A): Option[Int] = macro Ops.binop[A, Option[Int]] - def pmin(rhs: A): Option[A] = macro Ops.binop[A, A] - def pmax(rhs: A): Option[A] = macro Ops.binop[A, A] + def tryCompare(rhs: A): Option[Int] = macro Ops.binop[A, Option[Int]] + def pmin(rhs: A): Option[A] = macro Ops.binop[A, A] + def pmax(rhs: A): Option[A] = macro Ops.binop[A, A] - def >(rhs: Int)(implicit ev1: Ring[A]): Bool = macro Ops.binopWithLift[Int, Ring[A], A] + def >(rhs: Int)(implicit ev1: Ring[A]): Bool = macro Ops.binopWithLift[Int, Ring[A], A] def >=(rhs: Int)(implicit ev1: Ring[A]): Bool = macro Ops.binopWithLift[Int, Ring[A], A] - def <(rhs: Int)(implicit ev1: Ring[A]): Bool = macro Ops.binopWithLift[Int, Ring[A], A] + def <(rhs: Int)(implicit ev1: Ring[A]): Bool = macro Ops.binopWithLift[Int, Ring[A], A] def <=(rhs: Int)(implicit ev1: Ring[A]): Bool = macro Ops.binopWithLift[Int, Ring[A], A] - def >(rhs: Double)(implicit ev1: Field[A]): Bool = macro Ops.binopWithLift[Int, Field[A], A] + def >(rhs: Double)(implicit ev1: Field[A]): Bool = macro Ops.binopWithLift[Int, Field[A], A] def >=(rhs: Double)(implicit ev1: Field[A]): Bool = macro Ops.binopWithLift[Int, Field[A], A] - def <(rhs: Double)(implicit ev1: Field[A]): Bool = macro Ops.binopWithLift[Int, Field[A], A] + def <(rhs: Double)(implicit ev1: Field[A]): Bool = macro Ops.binopWithLift[Int, Field[A], A] def <=(rhs: Double)(implicit ev1: Field[A]): Bool = macro Ops.binopWithLift[Int, Field[A], A] - def >(rhs: spire.math.Number)(implicit c: ConvertableFrom[A]): Bool = (c.toNumber(lhs) > rhs).B + def >(rhs: spire.math.Number)(implicit c: ConvertableFrom[A]): Bool = (c.toNumber(lhs) > rhs).B def >=(rhs: spire.math.Number)(implicit c: ConvertableFrom[A]): Bool = (c.toNumber(lhs) >= rhs).B - def <(rhs: spire.math.Number)(implicit c: ConvertableFrom[A]): Bool = (c.toNumber(lhs) < rhs).B + def <(rhs: spire.math.Number)(implicit c: ConvertableFrom[A]): Bool = (c.toNumber(lhs) < rhs).B def <=(rhs: spire.math.Number)(implicit c: ConvertableFrom[A]): Bool = (c.toNumber(lhs) <= rhs).B } final class OrderOps[A <: Data](lhs: A)(implicit ev: Order[A]) { def compare(rhs: A): ComparisonBundle = macro Ops.binop[A, ComparisonBundle] - def min(rhs: A): A = macro Ops.binop[A, A] - def max(rhs: A): A = macro Ops.binop[A, A] + def min(rhs: A): A = macro Ops.binop[A, A] + def max(rhs: A): A = macro Ops.binop[A, A] def compare(rhs: Int)(implicit ev1: Ring[A]): Int = macro Ops.binopWithLift[Int, Ring[A], A] - def min(rhs: Int)(implicit ev1: Ring[A]): A = macro Ops.binopWithLift[Int, Ring[A], A] - def max(rhs: Int)(implicit ev1: Ring[A]): A = macro Ops.binopWithLift[Int, Ring[A], A] + def min(rhs: Int)(implicit ev1: Ring[A]): A = macro Ops.binopWithLift[Int, Ring[A], A] + def max(rhs: Int)(implicit ev1: Ring[A]): A = macro Ops.binopWithLift[Int, Ring[A], A] def compare(rhs: Double)(implicit ev1: Field[A]): Int = macro Ops.binopWithLift[Int, Field[A], A] - def min(rhs: Double)(implicit ev1: Field[A]): A = macro Ops.binopWithLift[Int, Field[A], A] - def max(rhs: Double)(implicit ev1: Field[A]): A = macro Ops.binopWithLift[Int, Field[A], A] + def min(rhs: Double)(implicit ev1: Field[A]): A = macro Ops.binopWithLift[Int, Field[A], A] + def max(rhs: Double)(implicit ev1: Field[A]): A = macro Ops.binopWithLift[Int, Field[A], A] - def compare(rhs: spire.math.Number)(implicit c: ConvertableFrom[A]): Int = c.toNumber(lhs) compare rhs - def min(rhs: spire.math.Number)(implicit c: ConvertableFrom[A]): Number = c.toNumber(lhs) min rhs - def max(rhs: spire.math.Number)(implicit c: ConvertableFrom[A]): Number = c.toNumber(lhs) max rhs + def compare(rhs: spire.math.Number)(implicit c: ConvertableFrom[A]): Int = c.toNumber(lhs).compare(rhs) + def min(rhs: spire.math.Number)(implicit c: ConvertableFrom[A]): Number = c.toNumber(lhs).min(rhs) + def max(rhs: spire.math.Number)(implicit c: ConvertableFrom[A]): Number = c.toNumber(lhs).max(rhs) } final class SignedOps[A: Signed](lhs: A) { - def abs(): A = macro Ops.unop[A] - def context_abs(): A = macro Ops.unop[A] - def sign(): Sign = macro Ops.unop[Sign] - def signum(): Int = macro Ops.unop[Int] - - def isSignZero(): Bool = macro Ops.unop[Bool] - def isSignPositive(): Bool = macro Ops.unop[Bool] - def isSignNegative(): Bool = macro Ops.unop[Bool] - - def isSignNonZero(): Bool = macro Ops.unop[Bool] - def isSignNonPositive(): Bool = macro Ops.unop[Bool] - def isSignNonNegative(): Bool = macro Ops.unop[Bool] + def abs: A = macro Ops.unop[A] + def context_abs: A = macro Ops.unop[A] + def sign: Sign = macro Ops.unop[Sign] + def signum: Int = macro Ops.unop[Int] + + def isSignZero: Bool = macro Ops.unop[Bool] + def isSignPositive: Bool = macro Ops.unop[Bool] + def isSignNegative: Bool = macro Ops.unop[Bool] + + def isSignNonZero: Bool = macro Ops.unop[Bool] + def isSignNonPositive: Bool = macro Ops.unop[Bool] + def isSignNonNegative: Bool = macro Ops.unop[Bool] } final class IsRealOps[A <: Data](lhs: A)(implicit ev: IsReal[A]) { - def isWhole(): Bool = macro Ops.unop[Bool] - def ceil(): A = macro Ops.unop[A] - def context_ceil(): A = macro Ops.unop[A] - def floor(): A = macro Ops.unop[A] - def round(): A = macro Ops.unop[A] - def truncate(): A = ev.truncate(lhs) + def isWhole: Bool = macro Ops.unop[Bool] + def ceil: A = macro Ops.unop[A] + def context_ceil: A = macro Ops.unop[A] + def floor: A = macro Ops.unop[A] + def round: A = macro Ops.unop[A] + def truncate: A = ev.truncate(lhs) } class IsIntegerOps[A <: Data](lhs: A)(implicit ev: IsIntegral[A]) { def mod(rhs: A): A = ev.mod(lhs, rhs) - def %(rhs: A): A = mod(rhs) - def isOdd(): Bool = ev.isOdd(lhs) - def isEven(): Bool = ev.isEven(lhs) + def %(rhs: A): A = mod(rhs) + def isOdd: Bool = ev.isOdd(lhs) + def isEven: Bool = ev.isEven(lhs) } class ConvertableToOps[A <: Data](lhs: A)(implicit ev: ConvertableTo[A]) { - def fromInt(i: Int): A = fromDouble(i.toDouble) - def fromIntWithFixedWidth(i: Int): A = fromDoubleWithFixedWidth(i.toDouble) - def fromDouble(d: Double): A = ev.fromDouble(d, lhs) + def fromInt(i: Int): A = fromDouble(i.toDouble) + def fromIntWithFixedWidth(i: Int): A = fromDoubleWithFixedWidth(i.toDouble) + def fromDouble(d: Double): A = ev.fromDouble(d, lhs) def fromDoubleWithFixedWidth(d: Double): A = ev.fromDoubleWithFixedWidth(d, lhs) } class ChiselConvertableFromOps[A <: Data](lhs: A)(implicit ev: ChiselConvertableFrom[A]) { - def intPart(): SInt = ev.intPart(lhs) - def asFixed(): FixedPoint = ev.asFixed(lhs) + def intPart: SInt = ev.intPart(lhs) + def asFixed: FixedPoint = ev.asFixed(lhs) def asFixed(proto: FixedPoint): FixedPoint = ev.asFixed(lhs, proto) - def asReal(): DspReal = ev.asReal(lhs) + def asReal: DspReal = ev.asReal(lhs) } class BinaryRepresentationOps[A <: Data](lhs: A)(implicit ev: BinaryRepresentation[A]) { - def <<(n: Int): A = ev.shl(lhs, n) + def <<(n: Int): A = ev.shl(lhs, n) def <<(n: UInt): A = ev.shl(lhs, n) - def >>(n: Int): A = ev.shr(lhs, n) + def >>(n: Int): A = ev.shr(lhs, n) def >>(n: UInt): A = ev.shr(lhs, n) - def signBit(): Bool = ev.signBit(lhs) - def div2(n: Int): A = ev.div2(lhs, n) - def mul2(n: Int): A = ev.mul2(lhs, n) + def signBit: Bool = ev.signBit(lhs) + def div2(n: Int): A = ev.div2(lhs, n) + def mul2(n: Int): A = ev.mul2(lhs, n) def trimBinary(n: Int): A = ev.trimBinary(lhs, n) } @@ -126,5 +126,5 @@ class ContextualRingOps[A <: Data](lhs: A)(implicit ev: Ring[A]) { def context_+(rhs: A): A = ev.plusContext(lhs, rhs) def context_-(rhs: A): A = ev.minusContext(lhs, rhs) def context_*(rhs: A): A = ev.timesContext(lhs, rhs) - def context_unary_-(): A = ev.negateContext(lhs) + def context_unary_- : A = ev.negateContext(lhs) } diff --git a/src/main/scala/dsptools/numbers/implicits/ImplicitSyntax.scala b/src/main/scala/dsptools/numbers/implicits/ImplicitSyntax.scala index f9d22a19..511fbf7e 100644 --- a/src/main/scala/dsptools/numbers/implicits/ImplicitSyntax.scala +++ b/src/main/scala/dsptools/numbers/implicits/ImplicitSyntax.scala @@ -7,41 +7,43 @@ import chisel3.Data import scala.language.implicitConversions trait EqSyntax { - implicit def eqOps[A <: Data:Eq](a: A): EqOps[A] = new EqOps(a) + implicit def eqOps[A <: Data: Eq](a: A): EqOps[A] = new EqOps(a) } trait PartialOrderSyntax extends EqSyntax { - implicit def partialOrderOps[A <: Data:PartialOrder](a: A): PartialOrderOps[A] = new PartialOrderOps(a) + implicit def partialOrderOps[A <: Data: PartialOrder](a: A): PartialOrderOps[A] = new PartialOrderOps(a) } trait OrderSyntax extends PartialOrderSyntax { - implicit def orderOps[A <: Data:Order](a: A): OrderOps[A] = new OrderOps(a) + implicit def orderOps[A <: Data: Order](a: A): OrderOps[A] = new OrderOps(a) } trait SignedSyntax { - implicit def signedOps[A <: Data:Signed](a: A): SignedOps[A] = new SignedOps(a) + implicit def signedOps[A <: Data: Signed](a: A): SignedOps[A] = new SignedOps(a) } trait IsRealSyntax extends OrderSyntax with SignedSyntax { - implicit def isRealOps[A <: Data:IsReal](a: A): IsRealOps[A] = new IsRealOps(a) + implicit def isRealOps[A <: Data: IsReal](a: A): IsRealOps[A] = new IsRealOps(a) } trait IsIntegerSyntax extends IsRealSyntax { - implicit def isIntegerOps[A <: Data:IsIntegral](a: A): IsIntegerOps[A] = new IsIntegerOps(a) + implicit def isIntegerOps[A <: Data: IsIntegral](a: A): IsIntegerOps[A] = new IsIntegerOps(a) } trait ConvertableToSyntax { - implicit def convertableToOps[A <: Data:ConvertableTo](a: A): ConvertableToOps[A] = new ConvertableToOps(a) + implicit def convertableToOps[A <: Data: ConvertableTo](a: A): ConvertableToOps[A] = new ConvertableToOps(a) } trait ChiselConvertableFromSyntax { - implicit def chiselConvertableFromOps[A <: Data:ChiselConvertableFrom](a: A): ChiselConvertableFromOps[A] = new ChiselConvertableFromOps(a) + implicit def chiselConvertableFromOps[A <: Data: ChiselConvertableFrom](a: A): ChiselConvertableFromOps[A] = + new ChiselConvertableFromOps(a) } trait BinaryRepresentationSyntax { - implicit def binaryRepresentationOps[A <: Data:BinaryRepresentation](a: A): BinaryRepresentationOps[A] = new BinaryRepresentationOps(a) + implicit def binaryRepresentationOps[A <: Data: BinaryRepresentation](a: A): BinaryRepresentationOps[A] = + new BinaryRepresentationOps(a) } trait ContextualRingSyntax { - implicit def contextualRingOps[A <: Data:Ring](a: A): ContextualRingOps[A] = new ContextualRingOps(a) -} \ No newline at end of file + implicit def contextualRingOps[A <: Data: Ring](a: A): ContextualRingOps[A] = new ContextualRingOps(a) +} diff --git a/src/main/scala/dsptools/numbers/implicits/ImplicitsTop.scala b/src/main/scala/dsptools/numbers/implicits/ImplicitsTop.scala index 2590ac78..d5a1f5db 100644 --- a/src/main/scala/dsptools/numbers/implicits/ImplicitsTop.scala +++ b/src/main/scala/dsptools/numbers/implicits/ImplicitsTop.scala @@ -2,10 +2,18 @@ package dsptools.numbers -trait AllSyntax extends EqSyntax with PartialOrderSyntax with OrderSyntax with SignedSyntax with IsRealSyntax with IsIntegerSyntax with - ConvertableToSyntax with ChiselConvertableFromSyntax with BinaryRepresentationSyntax with ContextualRingSyntax +trait AllSyntax + extends EqSyntax + with PartialOrderSyntax + with OrderSyntax + with SignedSyntax + with IsRealSyntax + with IsIntegerSyntax + with ConvertableToSyntax + with ChiselConvertableFromSyntax + with BinaryRepresentationSyntax + with ContextualRingSyntax trait AllImpl extends UIntImpl with SIntImpl with FixedPointImpl with DspRealImpl with DspComplexImpl -object implicits extends AllSyntax with AllImpl with spire.syntax.AllSyntax { -} +object implicits extends AllSyntax with AllImpl with spire.syntax.AllSyntax {} diff --git a/src/main/scala/dsptools/numbers/number_types/Numbers.scala b/src/main/scala/dsptools/numbers/number_types/Numbers.scala index 64dfe6eb..b66b8091 100644 --- a/src/main/scala/dsptools/numbers/number_types/Numbers.scala +++ b/src/main/scala/dsptools/numbers/number_types/Numbers.scala @@ -26,7 +26,7 @@ trait IsReal[A <: Data] extends Any with Order[A] with Signed[A] { def floor(a: A): A /** - * Rounds `a` to the nearest integer + * Rounds `a` to the nearest integer * (When the fractional part is 0.5, tie breaking rounds to positive infinity i.e. round half up) */ def round(a: A): A @@ -45,15 +45,15 @@ object IsReal { } trait IsIntegral[A <: Data] extends Any with IsReal[A] { - def ceil(a: A): A = a - def floor(a: A): A = a - def round(a: A): A = a + def ceil(a: A): A = a + def floor(a: A): A = a + def round(a: A): A = a def isWhole(a: A): Bool = true.B - + def mod(a: A, b: A): A - def isOdd(a: A): Bool - def isEven(a: A): Bool = !isOdd(a) + def isOdd(a: A): Bool + def isEven(a: A): Bool = !isOdd(a) def truncate(a: A): A = a } @@ -64,9 +64,9 @@ object IsIntegral { ///////////////////////////////////////////////////////////////////////////////////// trait Real[A <: Data] extends Any with Ring[A] with ConvertableTo[A] with IsReal[A] { - def fromRational(a: spire.math.Rational): A = fromDouble(a.toDouble) + def fromRational(a: spire.math.Rational): A = fromDouble(a.toDouble) def fromAlgebraic(a: spire.math.Algebraic): A = fromDouble(a.toDouble) - def fromReal(a: spire.math.Real): A = fromDouble(a.toDouble) + def fromReal(a: spire.math.Real): A = fromDouble(a.toDouble) } object Real { @@ -83,4 +83,4 @@ trait Integer[A <: Data] extends Any with Real[A] with IsIntegral[A] object Integer { @inline final def apply[A <: Data](implicit ev: Integer[A]): Integer[A] = ev -} \ No newline at end of file +} diff --git a/src/main/scala/dsptools/numbers/package.scala b/src/main/scala/dsptools/numbers/package.scala index 76df3456..62c53abd 100644 --- a/src/main/scala/dsptools/numbers/package.scala +++ b/src/main/scala/dsptools/numbers/package.scala @@ -2,20 +2,23 @@ package dsptools -package object numbers extends AllSyntax with AllImpl with spire.syntax.RingSyntax -/*with spire.syntax.AllSyntax*/ { - type AdditiveGroup[T] = spire.algebra.AdditiveGroup[T] - type CMonoid[T] = spire.algebra.CMonoid[T] - type ConvertableFrom[T] = spire.math.ConvertableFrom[T] - type Field[T] = spire.algebra.Field[T] +package object numbers + extends AllSyntax + with AllImpl + with spire.syntax.RingSyntax + /*with spire.syntax.AllSyntax*/ { + type AdditiveGroup[T] = spire.algebra.AdditiveGroup[T] + type CMonoid[T] = spire.algebra.CMonoid[T] + type ConvertableFrom[T] = spire.math.ConvertableFrom[T] + type Field[T] = spire.algebra.Field[T] type MultiplicativeAction[T, U] = spire.algebra.MultiplicativeAction[T, U] - type MultiplicativeCMonoid[T] = spire.algebra.MultiplicativeCMonoid[T] + type MultiplicativeCMonoid[T] = spire.algebra.MultiplicativeCMonoid[T] - val Multiplicative = spire.algebra.Multiplicative + val Multiplicative = spire.algebra.Multiplicative // rounding aliases - val Floor = RoundDown - val Ceiling = RoundUp - val Convergent = RoundHalfToEven - val Round = RoundHalfTowardsInfinity + val Floor = RoundDown + val Ceiling = RoundUp + val Convergent = RoundHalfToEven + val Round = RoundHalfTowardsInfinity } diff --git a/src/main/scala/dsptools/numbers/representations/BaseN.scala b/src/main/scala/dsptools/numbers/representations/BaseN.scala index b7fdfbd9..b265d4c2 100644 --- a/src/main/scala/dsptools/numbers/representations/BaseN.scala +++ b/src/main/scala/dsptools/numbers/representations/BaseN.scala @@ -22,6 +22,7 @@ object BaseN { // Should return non-empty list if (temp.isEmpty) Seq(0) else temp } + /** Zero pads Seq[Int] base-r representation */ def toDigitSeqMSDFirst(n: Int, r: Int, maxn: Int): Seq[Int] = { val digitSeq = toDigitSeqMSDFirst(n, r) @@ -33,4 +34,4 @@ object BaseN { /** Returns # of Base r digits needed to represent the number n */ def numDigits(n: Int, r: Int): Int = toDigitSeqInternal(n, r).length -} \ No newline at end of file +} diff --git a/src/main/scala/dsptools/numbers/rounding/Saturate.scala b/src/main/scala/dsptools/numbers/rounding/Saturate.scala deleted file mode 100644 index 2f42334c..00000000 --- a/src/main/scala/dsptools/numbers/rounding/Saturate.scala +++ /dev/null @@ -1,282 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 - -package dsptools.numbers.rounding - -import chisel3.{fromDoubleToLiteral => _, fromIntToBinaryPoint => _, _} -import chisel3.experimental.{ChiselAnnotation, RunFirrtlTransform, annotate, requireIsHardware} -import fixedpoint._ -import chisel3.stage.ChiselStage -import firrtl.{CircuitForm, CircuitState, HighForm, MidForm, Transform} -import firrtl.annotations.{ModuleName, SingleTargetAnnotation, Target} -import firrtl.ir.{Block, DefModule, FixedType, IntWidth, SIntType, UIntType, Module => FModule} - -import scala.collection.immutable.HashMap -import scala.language.existentials - -sealed trait SaturatingOp -case object SaturatingAdd extends SaturatingOp -case object SaturatingSub extends SaturatingOp - -case class SaturateAnnotation(target: ModuleName, op: SaturatingOp, pipe: Int = 0) extends SingleTargetAnnotation[ModuleName] { - def duplicate(t: ModuleName): SaturateAnnotation = this.copy(target = t) -} - -case class SaturateChiselAnnotation(target: SaturateDummyModule[_ <: Data], op: SaturatingOp, pipe: Int = 0) extends ChiselAnnotation with RunFirrtlTransform { - def toFirrtl: SaturateAnnotation = SaturateAnnotation(target.toTarget, op = op, pipe = pipe) - def transformClass: Class[SaturateTransform] = classOf[SaturateTransform] -} - -trait SaturateModule[T <: Data] extends Module { - val a: T - val b: T - val c: T -} - -class SaturateUIntAddModule(aWidth: Int, bWidth: Int, cWidth: Int, pipe: Int) extends SaturateModule[UInt] { - require(pipe == 0, "pipe not implemented yet") - - val a = IO(Input(UInt(aWidth.W))) - val b = IO(Input(UInt(bWidth.W))) - val c = IO(Output(UInt(cWidth.W))) - - val max = ((1 << cWidth) - 1).U - val sumWithGrow = a +& b - val tooBig = sumWithGrow(cWidth) - val sum = sumWithGrow(cWidth - 1, 0) - - c := Mux(tooBig, max, sum) -} - -class SaturateUIntSubModule(aWidth: Int, bWidth: Int, cWidth: Int, pipe: Int) extends SaturateModule[UInt] { - require(pipe == 0, "pipe not implemented yet") - val a = IO(Input(UInt(aWidth.W))) - val b = IO(Input(UInt(bWidth.W))) - val c = IO(Output(UInt(cWidth.W))) - - val tooSmall = a < b - val diff = a -% b - - c := Mux(tooSmall, 0.U, diff) -} - -class SaturateSIntAddModule(aWidth: Int, bWidth: Int, cWidth: Int, pipe: Int) extends SaturateModule[SInt] { - require(pipe == 0, "pipe not implemented yet") - val a = IO(Input(SInt(aWidth.W))) - val b = IO(Input(SInt(bWidth.W))) - val c = IO(Output(SInt(cWidth.W))) - - val abWidth = aWidth max bWidth - val max = ((1 << (cWidth - 1)) - 1).S - val min = (-(1 << (cWidth - 1))).S - val sumWithGrow = a +& b - - val tooBig = !sumWithGrow(abWidth) && sumWithGrow(abWidth - 1) - val tooSmall = sumWithGrow(abWidth) && !sumWithGrow(abWidth - 1) - - val sum = sumWithGrow(abWidth - 1, 0).asSInt - val fixTop = Mux(tooBig, max, sum) - val fixTopAndBottom = Mux(tooSmall, min, fixTop) - - c := fixTopAndBottom -} - -class SaturateSIntSubModule(aWidth: Int, bWidth: Int, cWidth: Int, pipe: Int) extends SaturateModule[SInt] { - require(pipe == 0, "pipe not implemented yet") - val a = IO(Input(SInt(aWidth.W))) - val b = IO(Input(SInt(bWidth.W))) - val c = IO(Output(SInt(cWidth.W))) - - val abWidth = aWidth max bWidth - val max = ((1 << (cWidth - 1)) - 1).S - val min = (-(1 << (cWidth - 1))).S - val sumWithGrow = a -& b - - val tooBig = !sumWithGrow(abWidth) && sumWithGrow(abWidth - 1) - val tooSmall = sumWithGrow(abWidth) && !sumWithGrow(abWidth - 1) - - val sum = sumWithGrow(cWidth - 1, 0).asSInt - val fixTop = Mux(tooBig, max, sum) - val fixTopAndBottom = Mux(tooSmall, min, fixTop) - - c := fixTopAndBottom -} - -class SaturateFixedPointAddModule( - aWidth: Int, aBP: Int, - bWidth: Int, bBP: Int, - cWidth: Int, cBP: Int, - pipe: Int) extends SaturateModule[FixedPoint] { - require(pipe == 0, "pipe not implemented yet") - - val a = IO(Input(FixedPoint(aWidth.W, aBP.BP))) - val b = IO(Input(FixedPoint(bWidth.W, bBP.BP))) - val c = IO(Output(FixedPoint(cWidth.W, cBP.BP))) - - - val max = (math.pow(2, (cWidth - cBP - 1)) - math.pow(2, -cBP)).F(cWidth.W, cBP.BP) - val min = (-math.pow(2, (cWidth - cBP - 1))).F(cWidth.W, cBP.BP) - val sumWithGrow = a +& b - - val tooBig = !sumWithGrow(cWidth) && sumWithGrow(cWidth - 1) - val tooSmall = sumWithGrow(cWidth) && !sumWithGrow(cWidth - 1) - - val sum = sumWithGrow(cWidth - 1, 0).asFixedPoint(cBP.BP) - val fixTop = Mux(tooBig, max, sum) - val fixTopAndBottom = Mux(tooSmall, min, fixTop) - - c := fixTopAndBottom -} - -class SaturateFixedPointSubModule( - aWidth: Int, aBP: Int, - bWidth: Int, bBP: Int, - cWidth: Int, cBP: Int, - pipe: Int) extends SaturateModule[FixedPoint] { - require(pipe == 0, "pipe not implemented yet") - - val a = IO(Input(FixedPoint(aWidth.W, aBP.BP))) - val b = IO(Input(FixedPoint(bWidth.W, bBP.BP))) - val c = IO(Output(FixedPoint(cWidth.W, cBP.BP))) - - val max = (math.pow(2, (cWidth - cBP - 1)) - math.pow(2, -cBP)).F(cWidth.W, cBP.BP) - val min = (-math.pow(2, (cWidth - cBP - 1))).F(cWidth.W, cBP.BP) - val diffWithGrow = a -& b - - val tooBig = !diffWithGrow(cWidth) && diffWithGrow(cWidth - 1) - val tooSmall = diffWithGrow(cWidth) && !diffWithGrow(cWidth - 1) - - val diff = diffWithGrow(cWidth - 1, 0).asFixedPoint(cBP.BP) - val fixTop = Mux(tooBig, max, diff) - val fixTopAndBottom = Mux(tooSmall, min, fixTop) - - c := fixTopAndBottom -} - -/** - * A module that serves as a placeholder for a saturating op. - * The frontend can't implement saturation easily when widths are unknown. This - * module inserts a dummy op that has the desired behavior in FIRRTL's width - * inference process. After width inference, this module will be replaced by an - * implementation of a saturating op. - */ -class SaturateDummyModule[T <: Data](aOutside: T, bOutside: T, op: (T, T) => T) extends SaturateModule[T] { - // this module should always be replaced in a transform - // throw in this assertion in case it isn't - assert(false.B) - val a = IO(Input(chiselTypeOf(aOutside))) - val b = IO(Input(chiselTypeOf(bOutside))) - val res = op(a, b) - val c = IO(Output(chiselTypeOf(res))) - c := res -} - -object Saturate { - private def op[T <: Data](a: T, b: T, widthOp: (T, T) => T, realOp: SaturatingOp, pipe: Int = 0): T = { - requireIsHardware(a) - requireIsHardware(b) - val saturate = Module(new SaturateDummyModule(a, b, widthOp)) - val anno = SaturateChiselAnnotation(saturate, realOp, pipe) - annotate(anno) - saturate.a := a - saturate.b := b - saturate.c - } - def addUInt(a: UInt, b: UInt, pipe: Int = 0): UInt = { - op(a, b, { (l: UInt, r: UInt) => l +% r }, SaturatingAdd, pipe) - } - def addSInt(a: SInt, b: SInt, pipe: Int = 0): SInt = { - op(a, b, { (l: SInt, r: SInt) => l +% r }, SaturatingAdd, pipe) - } - def addFixedPoint(a: FixedPoint, b: FixedPoint, pipe: Int = 0): FixedPoint = { - op(a, b, { (l: FixedPoint, r: FixedPoint) => (l +& r) >> 1 }, SaturatingAdd, pipe) - } - def subUInt(a: UInt, b: UInt, pipe: Int = 0): UInt = { - op(a, b, { (l: UInt, r: UInt) => l -% r }, SaturatingSub, pipe) - } - def subSInt(a: SInt, b: SInt, pipe: Int = 0): SInt = { - op(a, b, { (l: SInt, r: SInt) => l -% r }, SaturatingSub, pipe) - } - def subFixedPoint(a: FixedPoint, b: FixedPoint, pipe: Int = 0): FixedPoint = { - op(a, b, { (l: FixedPoint, r: FixedPoint) => (l -& r) >> 1 }, SaturatingSub, pipe) - } -} - -class SaturateTransform extends Transform { - def inputForm: CircuitForm = MidForm - def outputForm: CircuitForm = HighForm - - private def replaceMod(m: FModule, anno: SaturateAnnotation): FModule = { - val aTpe = m.ports.find(_.name == "a").map(_.tpe).getOrElse(throw new Exception("a not found")) - val bTpe = m.ports.find(_.name == "b").map(_.tpe).getOrElse(throw new Exception("b not found")) - val cTpe = m.ports.find(_.name == "c").map(_.tpe).getOrElse(throw new Exception("c not found")) - - val newMod = (aTpe, bTpe, cTpe, anno) match { - case ( - UIntType(IntWidth(aWidth)), - UIntType(IntWidth(bWidth)), - UIntType(IntWidth(cWidth)), - SaturateAnnotation(_, SaturatingAdd, pipe)) => - () => new SaturateUIntAddModule(aWidth.toInt, bWidth.toInt, cWidth.toInt, pipe = pipe) - case ( - UIntType(IntWidth(aWidth)), - UIntType(IntWidth(bWidth)), - UIntType(IntWidth(cWidth)), - SaturateAnnotation(_, SaturatingSub, pipe)) => - () => new SaturateUIntSubModule(aWidth.toInt, bWidth.toInt, cWidth.toInt, pipe = pipe) - case ( - SIntType(IntWidth(aWidth)), - SIntType(IntWidth(bWidth)), - SIntType(IntWidth(cWidth)), - SaturateAnnotation(_, SaturatingAdd, pipe)) => - () => new SaturateSIntAddModule(aWidth.toInt, bWidth.toInt, cWidth.toInt, pipe = pipe) - case ( - SIntType(IntWidth(aWidth)), - SIntType(IntWidth(bWidth)), - SIntType(IntWidth(cWidth)), - SaturateAnnotation(_, SaturatingSub, pipe)) => - () => new SaturateSIntSubModule(aWidth.toInt, bWidth.toInt, cWidth.toInt, pipe = pipe) - case ( - FixedType(IntWidth(aWidth), IntWidth(aBP)), - FixedType(IntWidth(bWidth), IntWidth(bBP)), - FixedType(IntWidth(cWidth), IntWidth(cBP)), - SaturateAnnotation(_, SaturatingAdd, pipe)) => - () => new SaturateFixedPointAddModule(aWidth.toInt, aBP.toInt, bWidth.toInt, bBP.toInt, (cWidth - 1).toInt, cBP.toInt, pipe = pipe) - case ( - FixedType(IntWidth(aWidth), IntWidth(aBP)), - FixedType(IntWidth(bWidth), IntWidth(bBP)), - FixedType(IntWidth(cWidth), IntWidth(cBP)), - SaturateAnnotation(_, SaturatingSub, pipe)) => - () => new SaturateFixedPointSubModule(aWidth.toInt, aBP.toInt, bWidth.toInt, bBP.toInt, (cWidth - 1).toInt, cBP.toInt, pipe = pipe) - } - // get new body from newMod (must be single module!) - - val newBody = ChiselStage.convert(newMod()).modules.head match { - case FModule(_, _, _, body) => body - case _ => throw new Exception("Saw blackbox for some reason") - } - m.copy(body = newBody) - } - - private def onModule(annos: Seq[SaturateAnnotation]) = { - val annoByName: HashMap[String, SaturateAnnotation] = HashMap(annos.map({ a => a.target.name -> a }): _*) - object SaturateAnnotation { - def unapply(name: String): Option[SaturateAnnotation] = { - annoByName.get(name) - } - } - def onModuleInner(m: DefModule): DefModule = m match { - case m@FModule(_, SaturateAnnotation(a), _, _) => - replaceMod(m, a) - case m => m - } - onModuleInner(_) - } - - def execute(state: CircuitState): CircuitState = { - val annos = state.annotations.collect { - case a: SaturateAnnotation => a - } - state.copy(circuit = state.circuit.copy(modules = - state.circuit.modules.map(onModule(annos)))) - } -} diff --git a/src/main/scala/examples/StreamingAutocorrelator.scala b/src/main/scala/examples/StreamingAutocorrelator.scala index 9faf994d..8f9b357f 100644 --- a/src/main/scala/examples/StreamingAutocorrelator.scala +++ b/src/main/scala/examples/StreamingAutocorrelator.scala @@ -7,8 +7,9 @@ import dsptools.{hasContext, DspContext, Grow} import dsptools.examples.TransposedStreamingFIR import spire.algebra.Ring -class StreamingAutocorrelator[T <: Data:Ring](inputGenerator: => T, outputGenerator: => T, delay: Int, windowSize: Int) - extends Module with hasContext { +class StreamingAutocorrelator[T <: Data: Ring](inputGenerator: => T, outputGenerator: => T, delay: Int, windowSize: Int) + extends Module + with hasContext { // implicit val ev2 = ev(context) val io = IO(new Bundle { val input = Input(inputGenerator) @@ -16,10 +17,11 @@ class StreamingAutocorrelator[T <: Data:Ring](inputGenerator: => T, outputGenera }) // create a sequence of registers (head is io.input) - val delays = (0 until delay + windowSize).scanLeft(io.input) { case (left, _) => - val nextReg = Reg(inputGenerator) - nextReg := left - nextReg + val delays = (0 until delay + windowSize).scanLeft(io.input) { + case (left, _) => + val nextReg = Reg(inputGenerator) + nextReg := left + nextReg } val window = delays.drop(delay + 1).reverse diff --git a/src/main/scala/examples/TransposedStreamingFIR.scala b/src/main/scala/examples/TransposedStreamingFIR.scala index 09e72a76..ff21b568 100644 --- a/src/main/scala/examples/TransposedStreamingFIR.scala +++ b/src/main/scala/examples/TransposedStreamingFIR.scala @@ -14,26 +14,26 @@ import spire.math.{ConvertableFrom, ConvertableTo} // This style preferred: // class CTTSF[T<:Data:Ring,V](i: T, o: T, val taps: Seq[V], conv: V=>T) -class ConstantTapTransposedStreamingFIR[T <: Data:Ring:ConvertableTo, V:ConvertableFrom]( - inputGenerator: T, - outputGenerator: T, - val taps: Seq[V]) - extends Module { +class ConstantTapTransposedStreamingFIR[T <: Data: Ring: ConvertableTo, V: ConvertableFrom]( + inputGenerator: T, + outputGenerator: T, + val taps: Seq[V]) + extends Module { val io = IO(new Bundle { - val input = Input(Valid(inputGenerator)) + val input = Input(Valid(inputGenerator)) val output = Output(Valid(outputGenerator)) }) val products: Seq[T] = taps.reverse.map { tap => - val t : T = implicitly[ConvertableTo[T]].fromType(tap) + val t: T = implicitly[ConvertableTo[T]].fromType(tap) io.input.bits * t } val last = Reg[T](outputGenerator) val nextLast = products.reduceLeft { (left: T, right: T) => val reg = Reg(left.cloneType) - when (io.input.valid) { + when(io.input.valid) { reg := left } reg + right @@ -46,13 +46,16 @@ class ConstantTapTransposedStreamingFIR[T <: Data:Ring:ConvertableTo, V:Converta io.output.valid := RegNext(io.input.valid) } -class TransposedStreamingFIR[T <: Data:Ring](inputGenerator: => T, outputGenerator: => T, - tapGenerator: => T, numberOfTaps: Int) - extends Module { +class TransposedStreamingFIR[T <: Data: Ring]( + inputGenerator: => T, + outputGenerator: => T, + tapGenerator: => T, + numberOfTaps: Int) + extends Module { val io = IO(new Bundle { - val input = Input(inputGenerator) // note, using as Input here, causes IntelliJ to not like '*' + val input = Input(inputGenerator) // note, using as Input here, causes IntelliJ to not like '*' val output = Output(outputGenerator) - val taps = Input(Vec(numberOfTaps, tapGenerator)) // note, using as Input here, causes IntelliJ to not like '*' + val taps = Input(Vec(numberOfTaps, tapGenerator)) // note, using as Input here, causes IntelliJ to not like '*' }) val products: Seq[T] = io.taps.reverse.map { tap: T => diff --git a/src/main/scala/examples/gainOffCorr.scala b/src/main/scala/examples/gainOffCorr.scala index f06c0b7e..42c27247 100644 --- a/src/main/scala/examples/gainOffCorr.scala +++ b/src/main/scala/examples/gainOffCorr.scala @@ -13,15 +13,15 @@ import spire.implicits._ // 3.Assuming the number of input sources = number of lanes for now // 4.Assuming that the memory interface for gain and offset values will be done at a higher level -class gainOffCorr[T<:Data:Ring](genIn: => T,genGain: => T,genOff: => T,genOut: => T, numLanes: Int) extends Module { - val io = IO(new Bundle { - val inputVal = Input(Vec(numLanes, genIn)) - val gainCorr = Input(Vec(numLanes, genGain)) - val offsetCorr = Input(Vec(numLanes, genOff)) - val outputVal = Output(Vec(numLanes, genOut)) - }) +class gainOffCorr[T <: Data: Ring](genIn: => T, genGain: => T, genOff: => T, genOut: => T, numLanes: Int) + extends Module { + val io = IO(new Bundle { + val inputVal = Input(Vec(numLanes, genIn)) + val gainCorr = Input(Vec(numLanes, genGain)) + val offsetCorr = Input(Vec(numLanes, genOff)) + val outputVal = Output(Vec(numLanes, genOut)) + }) - val inputGainCorr = io.inputVal.zip(io.gainCorr).map{case (in, gain) => in*gain } - io.outputVal := inputGainCorr.zip(io.offsetCorr).map{case (inGainCorr, offset) => inGainCorr + offset } + val inputGainCorr = io.inputVal.zip(io.gainCorr).map { case (in, gain) => in * gain } + io.outputVal := inputGainCorr.zip(io.offsetCorr).map { case (inGainCorr, offset) => inGainCorr + offset } } -