|
1 | 1 | package types |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "fmt" |
4 | 5 | "net/url" |
| 6 | + "reflect" |
5 | 7 | "testing" |
| 8 | + "time" |
6 | 9 |
|
7 | 10 | "github.com/lib/pq" |
8 | 11 | "github.com/stretchr/testify/require" |
@@ -92,3 +95,96 @@ func TestArray(t *testing.T) { |
92 | 95 | require.Nil(err) |
93 | 96 | require.Equal(input, v) |
94 | 97 | } |
| 98 | + |
| 99 | +func TestNullable(t *testing.T) { |
| 100 | + var ( |
| 101 | + Str string |
| 102 | + Int64 int64 |
| 103 | + Float64 float64 |
| 104 | + Bool bool |
| 105 | + Time time.Time |
| 106 | + Duration time.Duration |
| 107 | + ) |
| 108 | + tim := time.Now().UTC() |
| 109 | + tim = time.Date(tim.Year(), tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), 0, tim.Location()) |
| 110 | + s := require.New(t) |
| 111 | + cases := []struct { |
| 112 | + name string |
| 113 | + typ string |
| 114 | + nonNullInput interface{} |
| 115 | + dst interface{} |
| 116 | + }{ |
| 117 | + { |
| 118 | + "string", |
| 119 | + "text", |
| 120 | + "foo", |
| 121 | + &Str, |
| 122 | + }, |
| 123 | + { |
| 124 | + "int64", |
| 125 | + "bigint", |
| 126 | + int64(1), |
| 127 | + &Int64, |
| 128 | + }, |
| 129 | + { |
| 130 | + "float64", |
| 131 | + "decimal", |
| 132 | + float64(.5), |
| 133 | + &Float64, |
| 134 | + }, |
| 135 | + { |
| 136 | + "bool", |
| 137 | + "bool", |
| 138 | + true, |
| 139 | + &Bool, |
| 140 | + }, |
| 141 | + { |
| 142 | + "time.Duration", |
| 143 | + "bigint", |
| 144 | + 3 * time.Second, |
| 145 | + &Duration, |
| 146 | + }, |
| 147 | + { |
| 148 | + "time.Time", |
| 149 | + "timestamptz", |
| 150 | + tim, |
| 151 | + &Time, |
| 152 | + }, |
| 153 | + } |
| 154 | + |
| 155 | + db, err := openTestDB() |
| 156 | + s.Nil(err) |
| 157 | + |
| 158 | + defer func() { |
| 159 | + _, err = db.Exec("DROP TABLE IF EXISTS foo") |
| 160 | + s.Nil(err) |
| 161 | + s.Nil(db.Close()) |
| 162 | + }() |
| 163 | + |
| 164 | + for _, c := range cases { |
| 165 | + s.Nil(db.QueryRow("SELECT null").Scan(Nullable(c.dst)), c.name) |
| 166 | + elem := reflect.ValueOf(c.dst).Elem() |
| 167 | + zero := reflect.Zero(elem.Type()) |
| 168 | + s.Equal(zero.Interface(), elem.Interface(), c.name) |
| 169 | + |
| 170 | + var input = c.nonNullInput |
| 171 | + if v, ok := c.nonNullInput.(time.Duration); ok { |
| 172 | + input = int64(v) |
| 173 | + } |
| 174 | + |
| 175 | + _, err := db.Exec(fmt.Sprintf(`CREATE TABLE foo ( |
| 176 | + testcol %s |
| 177 | + )`, c.typ)) |
| 178 | + s.Nil(err, c.name) |
| 179 | + |
| 180 | + _, err = db.Exec("INSERT INTO foo (testcol) VALUES ($1)", input) |
| 181 | + s.Nil(err, c.name) |
| 182 | + |
| 183 | + s.Nil(db.QueryRow("SELECT testcol FROM foo").Scan(Nullable(c.dst)), c.name) |
| 184 | + elem = reflect.ValueOf(c.dst).Elem() |
| 185 | + s.Equal(c.nonNullInput, elem.Interface(), c.name) |
| 186 | + |
| 187 | + _, err = db.Exec("DROP TABLE foo") |
| 188 | + s.Nil(err, c.name) |
| 189 | + } |
| 190 | +} |
0 commit comments