data engineering

pyspark read mongodb, mysql

qkqhxla1 2021. 12. 28. 14:26

이전에 이거 관련해서 글을 썼었는데.. 너무 뒤죽박죽한 글 구성 + 잘 모르는데 여기저기서 이상하게 갖다붙힘 + 버전이 낮아짐에 따라 쓸모 없어진 글이 되버려서 이전 글은 삭제하고 다시 좀 다듬어서 정리합니다. 현재 쓰고있는 spark 3 버전 초반 기준입니다.


1. mongodb
공식 api : https://docs.mongodb.com/spark-connector/current/python-api/
몽고디비에서 데이터를 읽는 예시를 예로 듦. https://docs.mongodb.com/spark-connector/current/python/read-from-mongodb/
공식 홈페이지에 코드 예제가 있는데..

pipeline = "{'$match': {'type': 'apple'}}"
df = spark.read.format("mongo").option("pipeline", pipeline).load()
df.show()

스파크 환경에 익숙하지 않고서는 이해하기에 조금 부족한 예제다. 예로 스파크는 잡을 병렬적으로(parallel) 나눠서 빠르게 돌리는게 장점인데 위 코드는 그럼 몽고디비에서 큰 컬렉션을 가져온다고 가정할 때 어떻게 나눠서 가져올까?? 나도 처음에 생각하기로는 '몽고에서 제공하는 스파크 모듈이니까 어떻게 알아서 내부적으로 잘 하겠지?' 싶어서 신경을 안쓰고 지냈는데 이렇게 단순하게 생각할 경우 다른 디비에서 돌릴 경우 잘못 생각할 우려가 있다. 옵션을 하나만 더 알아두자.

https://docs.mongodb.com/spark-connector/current/configuration/에서 보이는 partitionKey 옵션이다.

spark = SparkSession.builder.appName("myApp").getOrCreate()
mongo_uri_format = ''
pipeline = "{'$match': {'type': 'apple'}}"
spark.read.format("mongo") \
            .option("spark.mongodb.input.partitionerOptions.partitionKey", "_id") \
            .option("spark.mongodb.input.uri", "{}?authSource=admin".format(mongo_uri_format)) \
            .option("pipeline", pipeline).load()

조금 풀어쓰자면 이처럼 된다. uri는 몽고에 접속하는 uri고 partitionKey만 명시적으로 적어주었는데 partitionKey를 기준으로 컬렉션을 쪼개서 가져온다. 몽고에 도큐먼스를 insert할때 자동으로 _id컬럼이 추가되고, _id컬럼은 인덱스가 자동으로 잡혀있는데 이를 이용해 잘 쪼개서 가져온다. 위처럼 Spark mongo connector를 사용할때는 
https://docs.mongodb.com/spark-connector/current/python-api/에 나온대로 

./bin/pyspark --packages org.mongodb.spark:mongo-spark-connector_2.12:3.0.1 

스파크 실행시 패키지를 추가해줘야 한다. 위 코드에서 spark.mongodb.input.uri를 옵션으로 줬는데, 스파크 실행시 옵션으로 줘도 된다.



2. mysql
위에 몽고디비 예시를 적을때 스파크가 어떻게 몽고디비에서 나눠서 가져오는지를 알아두자고 했는데 mysql처럼 컬럼이 정해져서 insert시에 정해진 컬럼만 들어가는 경우, 내가 따로 파티셔닝용 컬럼을 지정해놓지 않는 경우 partitioning할만한 컬럼이 없다.(몽고는 insert시 자동으로 _id컬럼이 추가되지만 mysql은 내가 만드는게 아닌이상 자동으로 추가되지 않음.) 일단 pyspark 에서 jdbc로 가져오는 mysql예제이다. (spark 3.1.x, mysql 8까지 호환될듯.)

from pyspark.sql import SparkSession
.........................
spark = SparkSession.builder.appName("myApp").getOrCreate()
...........................
# serverTimezone은 UTC로 해줘야함.. KST는 인식 못하는 버그가 있음. useSSL을 설정해주지 않으면 warning이 뜸.
df = spark.read.format("jdbc") \
    .option("driver", "com.mysql.cj.jdbc.Driver") \
    .option("url", "jdbc:mysql://{}/{}?serverTimezone=UTC&useSSL=false".format(SERVER, DB_NAME)) \
    .option("dbtable", "mytable") \
    .option("user", user) \
    .option("password", password) \
    .load()

이것도 실행시 packages추가를 해줘야 한다.

./bin/pyspark --packages mysql:mysql-connector-java:8.0.11

기본적으로 포함해야 하는 옵션은 넣었다. 근데 이런식으로만 설정하면 위에서 계속 말했지만 partitioning할만한 컬럼이 없어서 spark에서 mysql table을 가져올에 병렬적으로 테이블을 읽어오지 않는다.(제가 뭔가 틀렸으면 적어주세요.) 
테이블을 병렬적으로 나눠서 가져올수 없어서 테이블 크기가 큰 경우 하나의 익스큐터에서 테이블을 한번에 다 읽어오려다가 자바 힙 메모리가 터져버린다.

그래서 메모리가 터지지 않게 제대로 읽어오려면 partitionColumn, lowerBound, upperBound, numPartitions 4개 옵션을 설정해줘야 한다.
https://charsyam.wordpress.com/2021/09/06/%EC%9E%85-%EA%B0%9C%EB%B0%9C-spark-%EC%97%90%EC%84%9C-database-%EB%B9%A8%EB%A6%AC-%EB%8D%A4%ED%94%84%ED%95%98%EB%8A%94-%EB%B2%95parallelism/

문제는 오래된 테이블의 경우 크기가 큰 경우에 파티셔닝으로 사용할만한 컬럼이 없는 경우인데.. 이경우 alter table로 auto increment속성이 있는 컬럼을 추가하자니 너무 오래걸릴것 같아 애매한 경우가 있다. 어떻게든 잘라서 가져오려면 테이블에 인덱스가 존재하는 컬럼으로 where조건을 걸어서 잘라서 가져오는 방법이 있다.
아래처럼 dbtable옵션으로 넘길때 조건을 걸수 있다.

from pyspark.sql import SparkSession, DataFrame
from functools import reduce

..................

df_list = []
for column in ['a_value', 'b_value']:  # 필터링걸어서 가져옴. 한번에 가져오면 메모리가 터지는데 나눠서 가져오면 안 터진다.
    query = information_dict['query'].format(site)
    # serverTimezone은 UTC로 해줘야함.. KST는 인식 못하는 버그가 있음.
    df = self.spark.read.format("jdbc") \
        .option("driver", "com.mysql.cj.jdbc.Driver") \
        .option("url", "jdbc:mysql://{}/{}?serverTimezone=UTC&useSSL=false".format(SERVER, DB_NAME)) \
        .option("dbtable", '(select * from my_table where condition="{}") as my_table'.format(column)) \
        .option("user", self.user) \
        .option("password", self.password) \
        .load()
    df_list.append(df)
df = reduce(DataFrame.unionAll, df_list)  # 사이트별로 쪼개서 가져온 dataframe을 전부 하나로 합침