data engineering

pyspark Cannot cast DOCUMENT into a NullType 관련.

qkqhxla1 2021. 9. 20. 19:24

pyspark로 몽고에서 컬렉션을 읽어와 잡을 돌리고 있다. 몽고에서 컬렉션을 읽어와서 title이라는 필드에서 \t,\r,\n을 제거하고 이후의 프로세싱을 하기 위한 작업이다.

.......
df = self.spark.read.format("com.mongodb.spark.sql.DefaultSource") \
    .option("spark.mongodb.input.partitioner", "MongoSamplePartitioner") \
    .option("spark.mongodb.input.partitionerOptions.partitionKey", "_id") \
    .option("spark.mongodb.input.partitionerOptions.partitionSizeMB", 128) \
    .option("spark.mongodb.input.partitionerOptions.samplesPerPartition", 10) \
    .option("spark.mongodb.input.uri", "{}?authSource=admin".format(mongo_uri_format)) \
    .load()

column_list = ['category', 'title', '_id']
df.printSchema()
df.limit(10).select(*column_list).withColumn('title', regexp_replace('title', '\t|\r|\n', '')).show(20, truncate=False)

근데 위의 코드는 잘못된게 없는데(다른 컬렉션 대상으로는 잘 작동함.) 아래와 같은 에러가 발생하였다.

[2021-09-19 14:37:07,960] {subprocess.py:78} INFO - 21/09/19 14:37:07 WARN TaskSetManager: Lost task 1.0 in stage 1.0 (TID 2, x.x.x.x, executor 3): org.apache.spark.SparkException: Task failed while writing rows.
[2021-09-19 14:37:07,961] {subprocess.py:78} INFO - 	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.executeTask(FileFormatWriter.scala:291)
[2021-09-19 14:37:07,961] {subprocess.py:78} INFO - 	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.$anonfun$write$15(FileFormatWriter.scala:205)
[2021-09-19 14:37:07,961] {subprocess.py:78} INFO - 	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
[2021-09-19 14:37:07,961] {subprocess.py:78} INFO - 	at org.apache.spark.scheduler.Task.run(Task.scala:127)
[2021-09-19 14:37:07,961] {subprocess.py:78} INFO - 	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:446)
[2021-09-19 14:37:07,961] {subprocess.py:78} INFO - 	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1377)
[2021-09-19 14:37:07,961] {subprocess.py:78} INFO - 	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:449)
[2021-09-19 14:37:07,961] {subprocess.py:78} INFO - 	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
[2021-09-19 14:37:07,961] {subprocess.py:78} INFO - 	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
[2021-09-19 14:37:07,961] {subprocess.py:78} INFO - 	at java.lang.Thread.run(Thread.java:748)
[2021-09-19 14:37:07,961] {subprocess.py:78} INFO - Caused by: com.mongodb.spark.exceptions.MongoTypeConversionException: Cannot cast DOCUMENT into a NullType .......
[2021-09-19 14:37:07,961] {subprocess.py:78} INFO - 	at com.mongodb.spark.sql.MapFunctions$.convertToDataType(MapFunctions.scala:220)
[2021-09-19 14:37:07,961] {subprocess.py:78} INFO - 	at com.mongodb.spark.sql.MapFunctions$.$anonfun$documentToRow$1(MapFunctions.scala:37)

Cannot cast DOCUMENT into a NullType에 관련해서 검색을 좀 해봤는데 명확한 해답이 나오지 않다가 발견했다.
https://www.titanwolf.org/Network/q/067f183b-6831-45fd-99a3-cb4607c62604/y의 answer 2의 경우였는데, 이걸 읽어보면 MongoDB Connector for Spark는 스키마를 만들기 위해서 1000개의 샘플을 가져와서 데이터를 읽어보고, 한 필드에서 null의 값이 일정 비율이 넘으면 null로 판단한다고 한다.

여러모로 삽질해본결과 spark.read.format().load시에 존재하는 데이터를 사용해서 판단하는것 같다. 내가 실제로는 사용할 컬럼만 column_list로 선언해놓고, 이 필드들에는 널값이 존재하지 않고 이것만 가져와서 사용한다고 해도 그 이전에 다른 필드중 하나에 널값이 존재하는 비율이 일정 이상을 넘어가서 에러가 발생하는것이다.

위의 스택오버플로우 글에서는 sampleSize를 더 크게 설정해서 문제를 해결하라고 나와있는데 문제는 내 모든 도큐먼트의 특정 필드에 널값의 비율이 애초에 너무 높아서 sampleSize를 큰 값을 줘도 해결이 되는 문제가 아니라는거다. 그래서 찾다찾다 생각한게 애초에 몽고에서 가져올때 내가 사용할 column_list만 가져오면 괜찮지않을까? 생각했는데 그렇게 하니 해결됐다.

pipeline = "{$project: {category: 1 title: 1, _id: 1}}"
df = self.spark.read.format("com.mongodb.spark.sql.DefaultSource") \
    .option("spark.mongodb.input.partitioner", "MongoSamplePartitioner") \
    .option("spark.mongodb.input.partitionerOptions.partitionKey", "_id") \
    .option("spark.mongodb.input.partitionerOptions.partitionSizeMB", 128) \
    .option("spark.mongodb.input.partitionerOptions.samplesPerPartition", 10) \
    .option("spark.mongodb.input.uri", "{}?authSource=admin".format(mongo_uri_format)) \
    .option("pipeline", pipeline).load()

column_list = ['category', 'title', '_id']
df.printSchema()
df.limit(10).select(*column_list).withColumn('title', regexp_replace('title', '\t|\r|\n', '')).show(20, truncate=False)

mongodb find에서 특정 필드만 가져오는 법 : https://docs.mongodb.com/manual/tutorial/project-fields-from-query-results/
mongodb aggregate에서 위처럼 특정 필드만 가져오는 법 : https://docs.mongodb.com/manual/reference/operator/aggregation/project/