2023-01-19

AWS Redshift Data Import in Python

I haven't written anything for this blog for a long, long time, primarily because I changed my job from a technical role to a hybrid one.

Now, to keep it quick: let's assume you have some large datasets that need to be loaded into AWS Redshift. The best practice recommended by AWS is to load the file into an S3 bucket, then COPY it from S3 into Redshift.
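If you do get that kind of access, a minimal sketch of the S3 + COPY path might look like the block below. The bucket name, object key, IAM role ARN, connection string and table names are all placeholders you would swap for your own:

import boto3
from sqlalchemy import create_engine

bucket = "my-bucket"                                       # placeholder bucket
key = "staging/my_data.csv"                                # placeholder object key
iamRole = "arn:aws:iam::123456789012:role/myRedshiftRole"  # placeholder role attached to the cluster

# push the file up to S3 first
boto3.client("s3").upload_file("my_data.csv", bucket, key)

# then COPY it into Redshift (engine.execute with a raw string is SQLAlchemy 1.x style, same as the rest of this post)
engine = create_engine("postgresql://user:password@my-redshift-host:5439/mydb")
engine.execute("""
    copy myschema.mytable
    from 's3://{}/{}'
    iam_role '{}'
    csv ignoreheader 1
""".format(bucket, key, iamRole))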

Looks simple, huh? Now what if you are just an analytics person: would your IT department provision an S3 bucket for you? Would you be given COPY permission? Would your request be lodged as a Jira ticket and have to wait out a full sprint?

So let's try another way. The maximum size of a single SQL statement in Amazon Redshift is 16 MB, so the workaround here is to generate batched insert statements, e.g. INSERT INTO table(col1, col2...) VALUES (val1, val2...), ... But to keep each statement under that limit we need a pagination function. Once we have pagination, we want a configuration file for flexibility, and a log function would be great if any issue happens. So below is the list of requirements (a quick sketch of the generated statement follows the list):

  • source flat file
  • destination Redshift table name
  • configuration file to list source file delimiter, quotation, and source columns to destination column mapping
  • page size
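To make the idea concrete, here is roughly what one "page" of rows turns into. Table and column names are just placeholders; the real script below does the same thing with pagination, column mapping and logging:

rows = [["1", "alice"], ["2", "bob"]]   # one page of parsed csv rows
values = ",".join("('" + "','".join(v.replace("'", "''") for v in row) + "')" for row in rows)
stmt = "insert into myschema.mytable(id, name) values " + values
# => insert into myschema.mytable(id, name) values ('1','alice'),('2','bob')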
For the configuration file, below is an example:

{
    "delimiter": ",",
    "quotation": "\"",
    "columns": [
        {
            "flat col1": "table col1",
            "flat col2": "table col2"
        }
    ]
}
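If you'd rather not hand-type the column mapping, a small helper along these lines can read the CSV header and dump a skeleton config that maps every column to itself, ready for you to edit. This is purely a convenience sketch and not part of the import script; make_config_skeleton is a made-up name:

import csv
import json

def make_config_skeleton(src_file, out_file, delimiter=","):
    # read only the header row and map each source column to itself by default
    with open(src_file, newline="") as f:
        header = next(csv.reader(f, delimiter=delimiter))
    config = {
        "delimiter": delimiter,
        "quotation": "\"",
        "columns": [{c: c for c in header}]
    }
    with open(out_file, "w") as f:
        json.dump(config, f, indent=4)

make_config_skeleton("my_data.csv", "my_data_config.json")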

And the Python script is here:



#%%
from sqlalchemy import create_engine
import csv
import json
import logging
import os
import uuid
from datetime import datetime
from urllib.parse import quote

username = "" #<= your redshift username
password = "" #<= your redshift password
rsServer = "" #<= your redshift instance address (host:port/database)

engineRedshift = create_engine("postgresql://{}:{}@{}".format(quote(username), quote(password), rsServer))

destSchema = "" #<= destination redshift schema
destTable = "" #<= destination redshift table
srcFile = "" #<= source file 
configFile = "" #<= configuration json file
pageSize = 10000 #<= page size

if (username == "") or (password == ""):
    raise Exception("Must provide username/password")
#%%
# provide a guid for each execution in the log file
uid = uuid.uuid4()

# write the log next to the source file, one log per source file per day
logFileName = os.path.join(
    os.path.dirname(srcFile),
    "import_{}_{}.log".format(
        os.path.splitext(os.path.basename(srcFile))[0],
        datetime.now().strftime("%Y%m%d")))

logging.basicConfig(
    filename=logFileName, 
    format="{} - %(asctime)s - %(message)s".format(uid), 
    level=logging.INFO
    )

#%%
logging.info("source file: {}".format(srcFile))
logging.info("destination table: {}.{}".format(destSchema, destTable))
logging.info("configuration file: {}".format(configFile))
logging.info("pagination setting: {}".format(pageSize))

#%%
# step 1: parse configuration to get delimiter, quotation, source columns and dest columns
with open(configFile, "r") as f:
    s = json.load(f)

delimiter = s["delimiter"]
quotation = s["quotation"]

lstSrc = []
lstDest = []

for key in s["columns"][0]:
    lstSrc.append(key)
    lstDest.append(s["columns"][0][key])

lstSrc = [c.lower() for c in lstSrc]
lstDest = [c.lower() for c in lstDest]

#%%
logging.info("configuration delimiter: {}".format(delimiter))
logging.info("configuration quotation: {}".format("NA" if quotation == "" else quotation))
logging.info("configuration source columns: {}".format(lstSrc))
logging.info("configuration destination columns: {}".format(lstDest))

#%%
# step 2: all source columns should be in source csv file, read file header
# for extra columns in the source file, put column position in lstPop to drop
with open(srcFile, newline="") as flatFile:
    if quotation != "":
        reader = csv.reader(flatFile, delimiter=delimiter, quotechar=quotation)
    else: 
        reader = csv.reader(flatFile, delimiter=delimiter)
    header = next(reader)

header = [c.lower() for c in header]

lstPop = [i for i, c in enumerate(header) if c not in lstSrc]

# lstPop

logging.info("Columns to Drop from Source: {}".format(lstPop))

#%%
# step 3: read dest table schema
# if dest column in configuration file is not in redshift, raise error directly
# just like the MSSQL import wizard, if validation raises an error, simply stop the process
# note all destination columns need to be NULLABLE = YES, as we turn empty strings in the csv into NULL
sql = """select lower(column_name) as column_name, lower(is_nullable) as is_nullable
from information_schema.columns
where table_schema ilike '{}'
    and table_name ilike '{}'
order by ordinal_position""".format(destSchema, destTable)

lstR = [r for r in engineRedshift.execute(sql)]

lstCols = [r[0] for r in lstR]
for c in lstDest:
    if c not in lstCols:
        raise Exception("{} not in dest table".format(c))

lstNotNullable = [r[0] for r in lstR if r[1] == "no"]
if len(lstNotNullable) > 0:
    raise Exception("column(s) {} need to be NULLable".format(",".join(lstNotNullable)))

logging.info("Verify destination columns: Passed")
#%%
# function to construct the value list for an insert statement
# fObj => file object (or a list of lines)
# lst => positions of columns to drop
# minRow => start row (inclusive)
# maxRow => end row (exclusive)
# d => delimiter
# q => quotation
def construct_string_sql_by_fObj(fObj, lst, minRow, maxRow, d, q):
    value_list = ""

    try:
        if q != "":
            reader = csv.reader(fObj, delimiter=d, quotechar=q)
        else:
            reader = csv.reader(fObj, delimiter=d)
        page = [row for idx, row in enumerate(reader) if minRow <= idx < maxRow]

        for row in page:
            if len(lst) > 0:
                for i in sorted(lst, reverse=True):
                    del row[i]

            # escape single quotes and turn empty strings into NULL, value by value
            row = ["NULL" if r == "" else "'" + r.replace("'", "''") + "'" for r in row]
            value_list += "(" + ",".join(row) + "),"
    except Exception:
        logging.exception("failed to construct value list")
        return ""

    return value_list[:-1]


#%%
with open(srcFile, "r", newline="") as f:
    theFile = f.readlines()

conn = engineRedshift.connect()

for i in range(0, int(len(theFile) / pageSize) + 1):
    minRow = i * pageSize + 1
    maxRow = (i * pageSize + pageSize if i * pageSize + pageSize < len(theFile) else len(theFile)) + 1

    vlst = construct_string_sql_by_fObj(theFile, lstPop, minRow, maxRow, delimiter, quotation)

    # note: the file line count != the count of rows to insert, because quoted fields can contain newlines
    # the pagination is calculated from the file line count, so once an empty page is found all remaining pages can be skipped
    if vlst == "":
        break

    stmt = "insert into {}.{}({}) values ".format(destSchema, destTable, ",".join(lstDest))
    stmt += vlst

    print("from {} to {}".format(minRow, maxRow - 1))
    # print(stmt)

    logging.info("Data import start: from {} to {}".format(minRow, maxRow - 1))
    conn.execute(stmt)

conn.close()
print("import completed")

logging.info("File import completed\n\n\n\n")
        
# %%
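Once the script finishes, a quick sanity check worth running is comparing the table row count against the number of data lines in the file. This is just a sketch using the variables already defined above; note the file line count can overshoot slightly if quoted fields contain embedded newlines:

imported = engineRedshift.execute(
    "select count(*) from {}.{}".format(destSchema, destTable)
).scalar()
print("rows in table: {}, data lines in file: {}".format(imported, len(theFile) - 1))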



Enjoy!