diff --git a/CHANGELOG b/CHANGELOG
index 464675a4..13e157aa 100644
--- a/CHANGELOG
+++ b/CHANGELOG
@@ -1,5 +1,9 @@
# spark-redshift Changelog
+## 4.1.0
+
+- Add `include_column_list` parameter
+
## 4.0.2
- Trim SQL text for preactions and postactions, to fix empty SQL queries bug.
diff --git a/README.md b/README.md
index 433a5110..4d006aeb 100644
--- a/README.md
+++ b/README.md
@@ -619,6 +619,16 @@ must also set a distribution key with the distkey option.
Since setting usestagingtable=false operation risks data loss / unavailability, we have chosen to deprecate it in favor of requiring users to manually drop the destination table themselves.
+
+ | include_column_list |
+ No |
+ false |
+
+ If true, this library automatically extracts the columns from the schema
+ and adds them to the COPY command, according to the Column List docs
+ (e.g. `COPY "PUBLIC"."tablename" ("column1" [,"column2", ...])`).
+ |
+
| description |
No |
diff --git a/src/main/scala/io/github/spark_redshift_community/spark/redshift/Parameters.scala b/src/main/scala/io/github/spark_redshift_community/spark/redshift/Parameters.scala
index b2ab93f8..d0c63782 100644
--- a/src/main/scala/io/github/spark_redshift_community/spark/redshift/Parameters.scala
+++ b/src/main/scala/io/github/spark_redshift_community/spark/redshift/Parameters.scala
@@ -38,7 +38,8 @@ private[redshift] object Parameters {
"diststyle" -> "EVEN",
"usestagingtable" -> "true",
"preactions" -> ";",
- "postactions" -> ";"
+ "postactions" -> ";",
+ "include_column_list" -> "false"
)
val VALID_TEMP_FORMATS = Set("AVRO", "CSV", "CSV GZIP")
@@ -285,5 +286,11 @@ private[redshift] object Parameters {
new BasicSessionCredentials(accessKey, secretAccessKey, sessionToken))
}
}
+
+ /**
+ * Whether to list the schema's columns explicitly in the COPY command, e.g.
+ * `COPY "PUBLIC"."tablename" ("column1" [,"column2", ...])`. Key: `include_column_list`, default "false".
+ */
+ def includeColumnList: Boolean = parameters("include_column_list").toBoolean
}
}
diff --git a/src/main/scala/io/github/spark_redshift_community/spark/redshift/RedshiftWriter.scala b/src/main/scala/io/github/spark_redshift_community/spark/redshift/RedshiftWriter.scala
index 32dd5162..ad0994f6 100644
--- a/src/main/scala/io/github/spark_redshift_community/spark/redshift/RedshiftWriter.scala
+++ b/src/main/scala/io/github/spark_redshift_community/spark/redshift/RedshiftWriter.scala
@@ -86,6 +86,7 @@ private[redshift] class RedshiftWriter(
*/
private def copySql(
sqlContext: SQLContext,
+ schema: StructType,
params: MergedParameters,
creds: AWSCredentialsProvider,
manifestUrl: String): String = {
@@ -96,7 +97,13 @@ private[redshift] class RedshiftWriter(
case "AVRO" => "AVRO 'auto'"
case csv if csv == "CSV" || csv == "CSV GZIP" => csv + s" NULL AS '${params.nullString}'"
}
- s"COPY ${params.table.get} FROM '$fixedUrl' CREDENTIALS '$credsString' FORMAT AS " +
+ // Double any embedded quote so each column name stays a single quoted identifier.
+ val columns = if (!params.includeColumnList) "" else {
+ val quoted = schema.fieldNames.map(n => "\"" + n.replace("\"", "\"\"") + "\"")
+ quoted.mkString("(", ",", ") ")
+ }
+
+ s"COPY ${params.table.get} ${columns}FROM '$fixedUrl' CREDENTIALS '$credsString' FORMAT AS " +
s"${format} manifest ${params.extraCopyOptions}"
}
@@ -138,7 +145,7 @@ private[redshift] class RedshiftWriter(
manifestUrl.foreach { manifestUrl =>
// Load the temporary data into the new file
- val copyStatement = copySql(data.sqlContext, params, creds, manifestUrl)
+ val copyStatement = copySql(data.sqlContext, data.schema, params, creds, manifestUrl)
log.info(copyStatement)
try {
jdbcWrapper.executeInterruptibly(conn.prepareStatement(copyStatement))
diff --git a/src/test/scala/io/github/spark_redshift_community/spark/redshift/ParametersSuite.scala b/src/test/scala/io/github/spark_redshift_community/spark/redshift/ParametersSuite.scala
index faf5bc4c..b69c1a00 100644
--- a/src/test/scala/io/github/spark_redshift_community/spark/redshift/ParametersSuite.scala
+++ b/src/test/scala/io/github/spark_redshift_community/spark/redshift/ParametersSuite.scala
@@ -28,7 +28,8 @@ class ParametersSuite extends FunSuite with Matchers {
"tempdir" -> "s3://foo/bar",
"dbtable" -> "test_schema.test_table",
"url" -> "jdbc:redshift://foo/bar?user=user&password=password",
- "forward_spark_s3_credentials" -> "true")
+ "forward_spark_s3_credentials" -> "true",
+ "include_column_list" -> "true")
val mergedParams = Parameters.mergeParameters(params)
@@ -37,9 +38,14 @@ class ParametersSuite extends FunSuite with Matchers {
mergedParams.jdbcUrl shouldBe params("url")
mergedParams.table shouldBe Some(TableName("test_schema", "test_table"))
assert(mergedParams.forwardSparkS3Credentials)
+ assert(mergedParams.includeColumnList)
// Check that the defaults have been added
- (Parameters.DEFAULT_PARAMETERS - "forward_spark_s3_credentials").foreach {
+ (
+ Parameters.DEFAULT_PARAMETERS
+ - "forward_spark_s3_credentials"
+ - "include_column_list"
+ ).foreach {
case (key, value) => mergedParams.parameters(key) shouldBe value
}
}
diff --git a/src/test/scala/io/github/spark_redshift_community/spark/redshift/RedshiftSourceSuite.scala b/src/test/scala/io/github/spark_redshift_community/spark/redshift/RedshiftSourceSuite.scala
index 10fb7a93..deaab14d 100644
--- a/src/test/scala/io/github/spark_redshift_community/spark/redshift/RedshiftSourceSuite.scala
+++ b/src/test/scala/io/github/spark_redshift_community/spark/redshift/RedshiftSourceSuite.scala
@@ -442,6 +442,46 @@ class RedshiftSourceSuite
mockRedshift.verifyThatExpectedQueriesWereIssued(expectedCommands)
}
+ test("include_column_list=true adds the schema columns to the COPY query") {
+ // Opt in to the column list on top of the shared default parameters.
+ val writeParams = defaultParams ++ Map("include_column_list" -> "true")
+
+ val tableKey = TableName.parseFromEscaped(defaultParams("dbtable")).toString
+ val redshift = new MockRedshift(defaultParams("url"), Map(tableKey -> TestUtils.testSchema))
+
+ new DefaultSource(redshift.jdbcWrapper, _ => mockS3Client)
+ .createRelation(testSqlContext, SaveMode.Append, writeParams, expectedDataDF)
+
+ // The COPY statement must carry the quoted column list between the table name and FROM.
+ val createTablePattern = "CREATE TABLE IF NOT EXISTS \"PUBLIC\".\"test_table\" .*".r
+ val copyPattern = (
+ "COPY \"PUBLIC\".\"test_table\" \\(\"testbyte\",\"testbool\",\"testdate\"," +
+ "\"testdouble\",\"testfloat\",\"testint\",\"testlong\",\"testshort\",\"teststring\"," +
+ "\"testtimestamp\"\\) FROM .*"
+ ).r
+
+ redshift.verifyThatConnectionsWereClosed()
+ redshift.verifyThatExpectedQueriesWereIssued(Seq(createTablePattern, copyPattern))
+ }
+
+ test("include_column_list=false (default) does not add the schema columns to the COPY query") {
+ // No include_column_list override here: the write uses defaultParams unchanged.
+ val tableKey = TableName.parseFromEscaped(defaultParams("dbtable")).toString
+ val redshift = new MockRedshift(
+ defaultParams("url"),
+ Map(tableKey -> TestUtils.testSchema))
+
+ new DefaultSource(redshift.jdbcWrapper, _ => mockS3Client)
+ .createRelation(testSqlContext, SaveMode.Append, defaultParams, expectedDataDF)
+
+ // Expect FROM to follow the table name directly, with no column list in between.
+ val createTablePattern = "CREATE TABLE IF NOT EXISTS \"PUBLIC\".\"test_table\" .*".r
+ val copyPattern = "COPY \"PUBLIC\".\"test_table\" FROM .*".r
+
+ redshift.verifyThatConnectionsWereClosed()
+ redshift.verifyThatExpectedQueriesWereIssued(Seq(createTablePattern, copyPattern))
+ }
+
test("configuring maxlength on string columns") {
val longStrMetadata = new MetadataBuilder().putLong("maxlength", 512).build()
val shortStrMetadata = new MetadataBuilder().putLong("maxlength", 10).build()
@@ -594,4 +634,4 @@ class RedshiftSourceSuite
}
assert(e.getMessage.contains("Block FileSystem"))
}
-}
\ No newline at end of file
+}
diff --git a/version.sbt b/version.sbt
index cac72218..da6fe0f4 100644
--- a/version.sbt
+++ b/version.sbt
@@ -1 +1 @@
-version in ThisBuild := "4.0.2"
+version in ThisBuild := "4.1.0"