diff --git a/CHANGELOG b/CHANGELOG
index 464675a4..13e157aa 100644
--- a/CHANGELOG
+++ b/CHANGELOG
@@ -1,5 +1,9 @@
 # spark-redshift Changelog
 
+## 4.1.0
+
+- Add `include_column_list` parameter
+
 ## 4.0.2
 
 - Trim SQL text for preactions and postactions, to fix empty SQL queries bug.
diff --git a/README.md b/README.md
index 433a5110..4d006aeb 100644
--- a/README.md
+++ b/README.md
@@ -619,6 +619,16 @@ must also set a distribution key with the distkey option.

Since setting usestagingtable=false operation risks data loss / unavailability, we have chosen to deprecate it in favor of requiring users to manually drop the destination table themselves.

+
+include_column_list
+No
+false
+
+If true then this library will automatically extract the columns from the schema
+and add them to the COPY command according to the Column List docs.
+(e.g. `COPY "PUBLIC"."tablename" ("column1" [,"column2", ...])`).
+
+
 description
 No
diff --git a/src/main/scala/io/github/spark_redshift_community/spark/redshift/Parameters.scala b/src/main/scala/io/github/spark_redshift_community/spark/redshift/Parameters.scala
index b2ab93f8..d0c63782 100644
--- a/src/main/scala/io/github/spark_redshift_community/spark/redshift/Parameters.scala
+++ b/src/main/scala/io/github/spark_redshift_community/spark/redshift/Parameters.scala
@@ -38,7 +38,8 @@ private[redshift] object Parameters {
     "diststyle" -> "EVEN",
     "usestagingtable" -> "true",
     "preactions" -> ";",
-    "postactions" -> ";"
+    "postactions" -> ";",
+    "include_column_list" -> "false"
   )
 
   val VALID_TEMP_FORMATS = Set("AVRO", "CSV", "CSV GZIP")
@@ -285,5 +286,11 @@ private[redshift] object Parameters {
           new BasicSessionCredentials(accessKey, secretAccessKey, sessionToken))
       }
     }
+
+    /**
+     * If true then this library will extract the column list from the schema to
+     * include in the COPY command (e.g. `COPY "PUBLIC"."tablename" ("column1" [,"column2", ...])`)
+     */
+    def includeColumnList: Boolean = parameters("include_column_list").toBoolean
   }
 }
diff --git a/src/main/scala/io/github/spark_redshift_community/spark/redshift/RedshiftWriter.scala b/src/main/scala/io/github/spark_redshift_community/spark/redshift/RedshiftWriter.scala
index 32dd5162..ad0994f6 100644
--- a/src/main/scala/io/github/spark_redshift_community/spark/redshift/RedshiftWriter.scala
+++ b/src/main/scala/io/github/spark_redshift_community/spark/redshift/RedshiftWriter.scala
@@ -86,6 +86,7 @@ private[redshift] class RedshiftWriter(
    */
   private def copySql(
       sqlContext: SQLContext,
+      schema: StructType,
       params: MergedParameters,
       creds: AWSCredentialsProvider,
       manifestUrl: String): String = {
@@ -96,7 +97,13 @@ private[redshift] class RedshiftWriter(
       case "AVRO" => "AVRO 'auto'"
       case csv if csv == "CSV" || csv == "CSV GZIP" => csv + s" NULL AS '${params.nullString}'"
     }
-    s"COPY ${params.table.get} FROM '$fixedUrl' CREDENTIALS '$credsString' FORMAT AS " +
+    val columns = if (params.includeColumnList) {
+      "(" + schema.fieldNames.map(name => s""""$name"""").mkString(",") + ") "
+    } else {
+      ""
+    }
+
+    s"COPY ${params.table.get} ${columns}FROM '$fixedUrl' CREDENTIALS '$credsString' FORMAT AS " +
       s"${format} manifest ${params.extraCopyOptions}"
   }
 
@@ -138,7 +145,7 @@ private[redshift] class RedshiftWriter(
 
     manifestUrl.foreach { manifestUrl =>
       // Load the temporary data into the new file
-      val copyStatement = copySql(data.sqlContext, params, creds, manifestUrl)
+      val copyStatement = copySql(data.sqlContext, data.schema, params, creds, manifestUrl)
       log.info(copyStatement)
       try {
         jdbcWrapper.executeInterruptibly(conn.prepareStatement(copyStatement))
diff --git a/src/test/scala/io/github/spark_redshift_community/spark/redshift/ParametersSuite.scala b/src/test/scala/io/github/spark_redshift_community/spark/redshift/ParametersSuite.scala
index faf5bc4c..b69c1a00 100644
--- a/src/test/scala/io/github/spark_redshift_community/spark/redshift/ParametersSuite.scala
+++ b/src/test/scala/io/github/spark_redshift_community/spark/redshift/ParametersSuite.scala
@@ -28,7 +28,8 @@ class ParametersSuite extends FunSuite with Matchers {
       "tempdir" -> "s3://foo/bar",
       "dbtable" -> "test_schema.test_table",
       "url" -> "jdbc:redshift://foo/bar?user=user&password=password",
-      "forward_spark_s3_credentials" -> "true")
+      "forward_spark_s3_credentials" -> "true",
+      "include_column_list" -> "true")
 
     val mergedParams = Parameters.mergeParameters(params)
 
@@ -37,9 +38,14 @@ class ParametersSuite extends FunSuite with Matchers {
     mergedParams.jdbcUrl shouldBe params("url")
     mergedParams.table shouldBe Some(TableName("test_schema", "test_table"))
     assert(mergedParams.forwardSparkS3Credentials)
+    assert(mergedParams.includeColumnList)
 
     // Check that the defaults have been added
-    (Parameters.DEFAULT_PARAMETERS - "forward_spark_s3_credentials").foreach {
+    (
+      Parameters.DEFAULT_PARAMETERS
+        - "forward_spark_s3_credentials"
+        - "include_column_list"
+    ).foreach {
       case (key, value) => mergedParams.parameters(key) shouldBe value
     }
   }
diff --git a/src/test/scala/io/github/spark_redshift_community/spark/redshift/RedshiftSourceSuite.scala b/src/test/scala/io/github/spark_redshift_community/spark/redshift/RedshiftSourceSuite.scala
index 10fb7a93..deaab14d 100644
--- a/src/test/scala/io/github/spark_redshift_community/spark/redshift/RedshiftSourceSuite.scala
+++ b/src/test/scala/io/github/spark_redshift_community/spark/redshift/RedshiftSourceSuite.scala
@@ -442,6 +442,46 @@ class RedshiftSourceSuite
     mockRedshift.verifyThatExpectedQueriesWereIssued(expectedCommands)
   }
 
+  test("include_column_list=true adds the schema columns to the COPY query") {
+    val expectedCommands = Seq(
+      "CREATE TABLE IF NOT EXISTS \"PUBLIC\".\"test_table\" .*".r,
+
+      ("COPY \"PUBLIC\".\"test_table\" \\(\"testbyte\",\"testbool\",\"testdate\"," +
+        "\"testdouble\",\"testfloat\",\"testint\",\"testlong\",\"testshort\",\"teststring\"," +
+        "\"testtimestamp\"\\) FROM .*").r
+    )
+
+    val params = defaultParams ++ Map("include_column_list" -> "true")
+
+    val mockRedshift = new MockRedshift(
+      defaultParams("url"),
+      Map(TableName.parseFromEscaped(defaultParams("dbtable")).toString -> TestUtils.testSchema))
+
+    val source = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client)
+    source.createRelation(testSqlContext, SaveMode.Append, params, expectedDataDF)
+
+    mockRedshift.verifyThatConnectionsWereClosed()
+    mockRedshift.verifyThatExpectedQueriesWereIssued(expectedCommands)
+  }
+
+  test("include_column_list=false (default) does not add the schema columns to the COPY query") {
+    val expectedCommands = Seq(
+      "CREATE TABLE IF NOT EXISTS \"PUBLIC\".\"test_table\" .*".r,
+
+      "COPY \"PUBLIC\".\"test_table\" FROM .*".r
+    )
+
+    val mockRedshift = new MockRedshift(
+      defaultParams("url"),
+      Map(TableName.parseFromEscaped(defaultParams("dbtable")).toString -> TestUtils.testSchema))
+
+    val source = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client)
+    source.createRelation(testSqlContext, SaveMode.Append, defaultParams, expectedDataDF)
+
+    mockRedshift.verifyThatConnectionsWereClosed()
+    mockRedshift.verifyThatExpectedQueriesWereIssued(expectedCommands)
+  }
+
   test("configuring maxlength on string columns") {
     val longStrMetadata = new MetadataBuilder().putLong("maxlength", 512).build()
     val shortStrMetadata = new MetadataBuilder().putLong("maxlength", 10).build()
@@ -594,4 +634,4 @@ class RedshiftSourceSuite
     }
     assert(e.getMessage.contains("Block FileSystem"))
   }
-}
\ No newline at end of file
+}
diff --git a/version.sbt b/version.sbt
index cac72218..da6fe0f4 100644
--- a/version.sbt
+++ b/version.sbt
@@ -1 +1 @@
-version in ThisBuild := "4.0.2"
+version in ThisBuild := "4.1.0"