99#if  HAVE_SQLITE3 
1010  #include  <sqlite3.h> 
1111
12+ /* Constants for SQL processing limits */ 
13+ #define  BUFFER_SAFETY_MARGIN  64           /* Extra space for string operations */ 
14+ #define  MAX_PAREN_SEARCH_DISTANCE  1000    /* Prevent runaway parsing */ 
15+ 
1216struct  db_sqlite3  {
1317	/* The actual db connection.  */ 
1418	sqlite3  * conn ;
@@ -98,6 +102,21 @@ static const char *db_sqlite3_fmt_error(struct db_stmt *stmt)
98102		       sqlite3_errmsg (conn2sql (stmt -> db -> conn )));
99103}
100104
105+ static  bool  is_strict_constraint_error (struct  db_stmt  * stmt )
106+ {
107+ 	sqlite3  * sql  =  conn2sql (stmt -> db -> conn );
108+ 	const  char  * errmsg  =  sqlite3_errmsg (sql );
109+ 	int  errcode  =  sqlite3_errcode (sql );
110+ 
111+ 	if  (errcode  !=  SQLITE_CONSTRAINT  ||  !stmt -> db -> use_strict_tables )
112+ 		return  false;
113+ 
114+ 	return  (strstr (errmsg , "CHECK constraint failed" ) || 
115+ 		strstr (errmsg , "datatype mismatch" ) || 
116+ 		strstr (errmsg , "cannot store" ) || 
117+ 		strstr (errmsg , "NOT NULL constraint failed" ));
118+ }
119+ 
101120static  bool  db_sqlite3_setup (struct  db  * db , bool  create )
102121{
103122	char  * filename ;
@@ -205,16 +224,183 @@ static bool db_sqlite3_setup(struct db *db, bool create)
205224			   "PRAGMA foreign_keys = ON;" , -1 , & stmt , NULL );
206225	err  =  sqlite3_step (stmt );
207226	sqlite3_finalize (stmt );
208- 	return  err  ==  SQLITE_DONE ;
227+ 
228+ 	if  (err  !=  SQLITE_DONE )
229+ 		return  false;
230+ 
231+ 	bool  is_testing  =  (getenv ("TEST_DB_PROVIDER" ) || 
232+ 			   getenv ("PYTEST_PAR" ) || 
233+ 			   getenv ("TEST_DEBUG" ) || 
234+ 			   getenv ("VALGRIND" ));
235+ 
236+ 	/* SQLite 3.37.0 introduced STRICT table support */ 
237+ 	if  ((db -> developer  ||  is_testing ) &&  sqlite3_libversion_number () >= 3037000 )
238+ 		db -> use_strict_tables  =  true;
239+ 
240+ 	{
241+ 		static  const  char  * security_pragmas [] =  {
242+ 			"PRAGMA trusted_schema = OFF;" ,
243+ 			"PRAGMA cell_size_check = ON;" ,
244+ 			"PRAGMA secure_delete = ON;" ,
245+ 			NULL 
246+ 		};
247+ 
248+ 		for  (int  i  =  0 ; security_pragmas [i ]; i ++ ) {
249+ 			err  =  sqlite3_prepare_v2 (conn2sql (db -> conn ),
250+ 						 security_pragmas [i ], -1 , & stmt , NULL );
251+ 			if  (err  ==  SQLITE_OK ) {
252+ 				err  =  sqlite3_step (stmt );
253+ 				sqlite3_finalize (stmt );
254+ 			}
255+ 		}
256+ 	}
257+ 
258+ 	return  true;
259+ }
260+ 
261+ static  bool  is_standalone_keyword (const  char  * query , const  char  * pos ,
262+ 				       const  char  * keyword , size_t  keyword_len ,
263+ 				       size_t  query_len )
264+ {
265+ 	bool  prefix_ok  =  (pos  ==  query  ||  (!isalnum (pos [-1 ]) &&  pos [-1 ] !=  '_' ));
266+ 	const  char  * after  =  pos  +  keyword_len ;
267+ 	bool  suffix_ok  =  (after  >= query  +  query_len  || 
268+ 			  (!isalnum (after [0 ]) &&  after [0 ] !=  '_' ));
269+ 
270+ 	return  prefix_ok  &&  suffix_ok ;
271+ }
272+ 
273+ static  char  * normalize_types (const  tal_t  * ctx , const  char  * query )
274+ {
275+ 	char  * result ;
276+ 	const  char  * src ;
277+ 	char  * dst ;
278+ 	size_t  query_len ;
279+ 
280+ 	if  (!query )
281+ 		return  NULL ;
282+ 
283+ 	query_len  =  strlen (query );
284+ 
285+ 	#define  MAX_SQL_STATEMENT_LENGTH  1048576 /* 1MB limit */ 
286+ 	if  (query_len  >  MAX_SQL_STATEMENT_LENGTH )
287+ 		return  NULL ;
288+ 
289+ 	/* INT(3) -> INTEGER(7) worst case: +4 bytes per conversion */ 
290+ 	size_t  max_expansions  =  (query_len  / 3 ) *  4 ;
291+ 	size_t  buffer_size  =  query_len  +  max_expansions  +  BUFFER_SAFETY_MARGIN ;
292+ 
293+ 	if  (buffer_size  <  query_len )
294+ 		return  NULL ;
295+ 
296+ 	result  =  tal_arr (ctx , char , buffer_size );
297+ 	src  =  query ;
298+ 	dst  =  result ;
299+ 
300+ 	while  (* src ) {
301+ 		if  (strncasecmp (src , "BIGSERIAL" , 9 ) ==  0  && 
302+ 		    is_standalone_keyword (query , src , "BIGSERIAL" , 9 , query_len )) {
303+ 			strcpy (dst , "INTEGER" );
304+ 			dst  +=  7 ;
305+ 			src  +=  9 ;
306+ 		} else  if  (strncasecmp (src , "VARCHAR" , 7 ) ==  0  && 
307+ 			   is_standalone_keyword (query , src , "VARCHAR" , 7 , query_len )) {
308+ 			strcpy (dst , "TEXT" );
309+ 			dst  +=  4 ;
310+ 			src  +=  7 ;
311+ 
312+ 			if  (* src  ==  '(' ) {
313+ 				const  char  * paren_start  =  src ;
314+ 				while  (* src  &&  * src  !=  ')' ) {
315+ 					src ++ ;
316+ 					/* Prevent runaway on malformed SQL */ 
317+ 					if  (src  -  paren_start  >  MAX_PAREN_SEARCH_DISTANCE )
318+ 						return  NULL ;
319+ 				}
320+ 				if  (* src  ==  ')' ) src ++ ;
321+ 			}
322+ 		} else  if  (strncasecmp (src , "BIGINT" , 6 ) ==  0  && 
323+ 			   is_standalone_keyword (query , src , "BIGINT" , 6 , query_len )) {
324+ 			strcpy (dst , "INTEGER" );
325+ 			dst  +=  7 ;
326+ 			src  +=  6 ;
327+ 		} else  if  (strncasecmp (src , "INT" , 3 ) ==  0  && 
328+ 			   is_standalone_keyword (query , src , "INT" , 3 , query_len )) {
329+ 			strcpy (dst , "INTEGER" );
330+ 			dst  +=  7 ;
331+ 			src  +=  3 ;
332+ 		} else  {
333+ 			* dst ++  =  * src ++ ;
334+ 		}
335+ 	}
336+ 
337+ 	* dst  =  '\0' ;
338+ 	return  result ;
339+ }
340+ 
341+ static  char  * add_strict_keyword (const  tal_t  * ctx , const  char  * query )
342+ {
343+ 	char  * semicolon_pos ;
344+ 	ptrdiff_t  prefix_len ;
345+ 
346+ 	if  (!strcasestr (query , "CREATE TABLE" ))
347+ 		return  tal_strdup (ctx , query );
348+ 
349+ 	if  (strcasestr (query , "STRICT" ))
350+ 		return  tal_strdup (ctx , query );
351+ 
352+ 	semicolon_pos  =  strrchr (query , ';' );
353+ 	if  (!semicolon_pos )
354+ 		semicolon_pos  =  (char  * )query  +  strlen (query );
355+ 
356+ 	prefix_len  =  semicolon_pos  -  query ;
357+ 	return  tal_fmt (ctx , "%.*s STRICT%s" , (int )prefix_len ,
358+ 		       query , semicolon_pos );
359+ }
360+ 
361+ static  char  * prepare_query_for_exec (const  tal_t  * ctx , struct  db  * db ,
362+ 					  const  char  * query )
363+ {
364+ 	char  * normalized_query ;
365+ 
366+ 	normalized_query  =  normalize_types (ctx , query );
367+ 	if  (!normalized_query )
368+ 		return  NULL ;
369+ 
370+ 	if  (db -> use_strict_tables )
371+ 		return  add_strict_keyword (ctx , normalized_query );
372+ 	else 
373+ 		return  normalized_query ;
209374}
210375
211376static  bool  db_sqlite3_query (struct  db_stmt  * stmt )
212377{
213378	sqlite3_stmt  * s ;
214379	sqlite3  * conn  =  conn2sql (stmt -> db -> conn );
215380	int  err ;
381+ 	char  * query_to_execute ;
216382
217- 	err  =  sqlite3_prepare_v2 (conn , stmt -> query -> query , -1 , & s , NULL );
383+ 	query_to_execute  =  prepare_query_for_exec (stmt , stmt -> db ,
384+ 						       stmt -> query -> query );
385+ 	bool  should_free_query  =  (query_to_execute  !=  stmt -> query -> query );
386+ 
387+ 	err  =  sqlite3_prepare_v2 (conn , query_to_execute , -1 , & s , NULL );
388+ 
389+ 	if  (err  !=  SQLITE_OK ) {
390+ 		if  (should_free_query )
391+ 			tal_free (query_to_execute );
392+ 		tal_free (stmt -> error );
393+ 		if  (is_strict_constraint_error (stmt )) {
394+ 			stmt -> error  =  tal_fmt (stmt , "%s (Note: STRICT tables are enabled)" ,
395+ 					      db_sqlite3_fmt_error (stmt ));
396+ 		} else  {
397+ 			stmt -> error  =  db_sqlite3_fmt_error (stmt );
398+ 		}
399+ 		return  false;
400+ 	}
401+ 
402+ 	if  (should_free_query )
403+ 		tal_free (query_to_execute );
218404
219405	for  (size_t  i = 0 ; i < stmt -> query -> placeholders ; i ++ ) {
220406		struct  db_binding  * b  =  & stmt -> bindings [i ];
@@ -246,12 +432,6 @@ static bool db_sqlite3_query(struct db_stmt *stmt)
246432		}
247433	}
248434
249- 	if  (err  !=  SQLITE_OK ) {
250- 		tal_free (stmt -> error );
251- 		stmt -> error  =  db_sqlite3_fmt_error (stmt );
252- 		return  false;
253- 	}
254- 
255435	stmt -> inner_stmt  =  s ;
256436	return  true;
257437}
@@ -270,7 +450,12 @@ static bool db_sqlite3_exec(struct db_stmt *stmt)
270450	err  =  sqlite3_step (stmt -> inner_stmt );
271451	if  (err  !=  SQLITE_DONE ) {
272452		tal_free (stmt -> error );
273- 		stmt -> error  =  db_sqlite3_fmt_error (stmt );
453+ 		if  (is_strict_constraint_error (stmt )) {
454+ 			stmt -> error  =  tal_fmt (stmt , "%s (Note: STRICT tables are enabled)" ,
455+ 					      db_sqlite3_fmt_error (stmt ));
456+ 		} else  {
457+ 			stmt -> error  =  db_sqlite3_fmt_error (stmt );
458+ 		}
274459		return  false;
275460	}
276461
0 commit comments