diff --git a/preprocessing/process_sql.py b/preprocessing/process_sql.py
index 686ee1f..63b48ae 100644
--- a/preprocessing/process_sql.py
+++ b/preprocessing/process_sql.py
@@ -29,9 +29,9 @@
 import json
 import sqlite3
 from nltk import word_tokenize
 
 CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
-JOIN_KEYWORDS = ('join', 'on', 'as')
+JOIN_KEYWORDS = ('left', 'join', 'on', 'as')
 
 WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
 UNIT_OPS = ('none', '-', '+', "*", '/')
@@ -397,10 +397,15 @@ def parse_from(toks, start_idx, tables_with_alias, schema):
             idx, sql = parse_sql(toks, idx, tables_with_alias, schema)
             table_units.append((TABLE_TYPE['sql'], sql))
         else:
-            if idx < len_ and toks[idx] == 'join':
+            join_type = None
+            if idx < len_ - 1 and toks[idx] == 'left' and toks[idx+1] == 'join':
+                join_type = 'left'
+                idx += 2  # skip left join
+            elif idx < len_ and toks[idx] == 'join':
+                join_type = 'inner'
                 idx += 1  # skip join
             idx, table_unit, table_name = parse_table_unit(toks, idx, tables_with_alias, schema)
-            table_units.append((TABLE_TYPE['table_unit'],table_unit))
+            table_units.append((TABLE_TYPE['table_unit'],table_unit,join_type))
             default_tables.append(table_name)
             if idx < len_ and toks[idx] == "on":
                 idx += 1  # skip on