-- | Test values
--
-- Intended for unqualified import.
module Test.Tensor.TestValue (
    TestValue -- opaque
  ) where

import Data.List (sort)
import System.Random (Random)
import Test.QuickCheck
import Text.Printf (printf)

{-------------------------------------------------------------------------------
  Definition
-------------------------------------------------------------------------------}

-- | Floating-point values compared up to a coarse tolerance
--
-- A 'TestValue' wraps a 'Float' and is meant for QuickCheck properties over
-- floating point numbers in which rounding errors should be ignored.
--
-- All classes listed here are derived from the underlying 'Float' instance
-- through @deriving newtype@. The 'Eq', 'Ord', 'Show' and 'Arbitrary'
-- instances are written by hand in this module.
newtype TestValue = TestValue Float
  deriving newtype (
      -- Numeric hierarchy
      Num
    , Real
    , Enum
    , Fractional
    , RealFrac
    , Floating
    , RealFloat
      -- Generation and parsing
    , Random
    , Read
    )

-- | Approximate equality on test values
--
-- Two test values are equal whenever the wrapped floats are 'nearlyEqual'.
-- Because this is a tolerance-based comparison it is /not/ transitive.
--
-- >               (==)
-- > --------------------
-- > 1.0    1.1    False
-- > 1.00   1.01   True
-- > 10     11     False
-- > 10.0   10.1   True
-- > 100    110    False
-- > 100    101    True
instance Eq TestValue where
  (==) (TestValue lhs) (TestValue rhs) = lhs `nearlyEqual` rhs

-- | Show instance
--
-- We have more precision available for values of smaller magnitude, so we
-- show more decimals for those. However, for larger values the show instance
-- does not reflect the precision: @1000@ and @1001@ are shown as @1000@ and
-- @1001@, even though they are considered to be equal.
--
-- The number of decimals is selected from the /absolute/ value, so a negative
-- value is rendered with the same number of decimals as its positive
-- counterpart.
--
-- > show @TestValue 0     == "0"     -- True zero
-- > show @TestValue 1     == "1"     -- True one
-- > show @TestValue 0.001 == "0.00"
-- > show @TestValue 0.009 == "0.01"
-- > show @TestValue 1.001 == "1.0"
-- > show @TestValue 11    == "11"
-- > show @TestValue (-11) == "-11"
instance Show TestValue where
  show (TestValue x)
    | x == 0         = "0"
    | x == 1         = "1"
    | magnitude < 1  = printf "%0.2f" x
    | magnitude < 10 = printf "%0.1f" x
    | otherwise      = printf "%0.0f" x
    where
      -- Decide on the number of decimals by magnitude, not by signed value;
      -- otherwise every negative number would fall into the two-decimal case.
      magnitude :: Float
      magnitude = abs x

-- | Arbitrary instance
--
-- The definition of 'arbitrary' simply piggy-backs on the definition for
-- 'Float', but in shrinking we avoid generating nearly equal values, and prefer
-- values closer to integral values. Compare:
--
-- >    shrink @TestValue 100.1
-- > == [0,50,75,88,94,97]
--
-- versus
--
-- >    shrink @Float 100.1
-- > == [100.0,0.0,50.0,75.0,88.0,94.0,97.0,99.0,0.0,50.1,75.1,87.6,93.9,97.0,98.6,99.4,99.8,100.0]
--
-- Shrinking proceeds in two steps:
--
-- 1. The sorted 'Float' shrink candidates are thinned out: every run of
--    neighbouring candidates that are nearly equal is collapsed to a single
--    representative, favouring the one with the smallest decimal part.
-- 2. Every remaining candidate that is nearly equal to the value being shrunk
--    is dropped, wherever it appears in the list (such a candidate would be
--    indistinguishable from the original under the 'Eq' instance).
instance Arbitrary TestValue where
  arbitrary = TestValue <$> arbitrary

  shrink (TestValue x)
    | x == 0          = []
    | nearlyEqual x 0 = [0]
    | otherwise       =
        [ TestValue y
        | y <- thin (sort (shrink x))
        , not (nearlyEqual y x)
        ]
    where
      -- Collapse runs of nearly equal neighbours in a sorted list
      thin :: [Float] -> [Float]
      thin []     = []
      thin (y:ys) = go y ys

      -- @go y rest@: @y@ is the current representative of the run in progress
      go :: Float -> [Float] -> [Float]
      go y []     = [y]
      go y (z:zs)
        | nearlyEqual y z = go (closerToIntegral y z) zs
        | otherwise       = y : go z zs

      -- Of two nearly equal candidates keep the one with the smaller decimal
      -- part; on a tie the later candidate wins.
      closerToIntegral :: Float -> Float -> Float
      closerToIntegral y z
        | decimalPart y < decimalPart z = y
        | otherwise                     = z

-- | Ordering that agrees with the approximate 'Eq' instance
--
-- Nearly equal values compare as 'EQ'. Otherwise the result is 'LT' when the
-- first wrapped float is strictly smaller than the second, and 'GT' in every
-- remaining case.
instance Ord TestValue where
  compare (TestValue lhs) (TestValue rhs)
    | lhs `nearlyEqual` rhs = EQ
    | otherwise             = if lhs < rhs then LT else GT

{-------------------------------------------------------------------------------
  Internal auxiliary
-------------------------------------------------------------------------------}

-- | Compare for near equality
--
-- Two values are nearly equal if they are exactly equal, or if their absolute
-- difference is below the larger of
--
-- * an absolute threshold (@abs_th@), and
-- * a relative threshold: @epsilon@ times the sum of the magnitudes of the
--   two values.
--
-- The sum of magnitudes is clamped to the largest finite 'Float'. Without the
-- clamp the sum overflows to infinity for very large arguments, making the
-- relative threshold infinite, so that any two huge finite values would be
-- considered nearly equal.
--
-- Adapted from <https://stackoverflow.com/a/32334103/742991>
nearlyEqual :: Float -> Float -> Bool
nearlyEqual a b
  | a == b    = True
  | otherwise = diff < max abs_th (epsilon * norm)
  where
    diff, norm :: Float
    diff = abs (a - b)
    norm = min (abs a + abs b) maxFinite

    -- Largest finite value of the type: all mantissa bits set, maximum exponent
    maxFinite :: Float
    maxFinite =
        encodeFloat
          (floatRadix a ^ floatDigits a - 1)
          (snd (floatRange a) - floatDigits a)

    -- Define precision
    abs_th, epsilon :: Float
    epsilon = 0.01
    abs_th  = 0.01

-- | Distance from a value down to the nearest integer not above it
--
-- > decimalPart 2.25    == 0.25
-- > decimalPart (-2.25) == 0.75
--
-- The floor is computed at type 'Integer' rather than 'Int': a 'Float' can be
-- far larger than any 'Int', and flooring such a value to 'Int' does not
-- produce a meaningful result.
decimalPart :: Float -> Float
decimalPart x = x - fromIntegral (floor x :: Integer)