diff --git a/build.sbt b/build.sbt index 173a074c..ecafc7bc 100644 --- a/build.sbt +++ b/build.sbt @@ -9,9 +9,21 @@ lazy val isFipsRelease = { // scalastyle:on println result } +lazy val includeFatJarsAndBundles = { + // When the PUBLISH environment variable is true, we assume the caller is publishing to Maven, + // in which case we do not want to include fat JAR, ZIP, or Tarball bundle artifacts. + val result = !sys.env.getOrElse("PUBLISH", "false").toBoolean + // scalastyle:off println + println(s"Including Fat JARs and Bundles in published artifacts: $result") + // scalastyle:on println + result +} +def isFatJarOrBundle(c: String): Boolean = + c.contains("with-dependencies") || c.contains("fat-test") || c.contains("bundle") + lazy val snowparkName = s"snowpark${if (isFipsRelease) "-fips" else ""}" lazy val jdbcName = s"snowflake-jdbc${if (isFipsRelease) "-fips" else ""}" -lazy val snowparkVersion = "1.17.0-SNAPSHOT" +lazy val snowparkVersion = "1.17.0" lazy val Javadoc = config("genjavadoc") extend Compile @@ -26,6 +38,7 @@ lazy val javadocSettings = inConfig(Javadoc)(Defaults.configSettings) ++ Seq( !(s.getParent.contains("internal") || s.getParent.contains("Internal"))), Javadoc / javacOptions := Seq( "--allow-script-in-comments", + "-use", "-windowtitle", s"Snowpark Java API Reference $snowparkVersion", "-doctitle", s"Snowpark Java API Reference $snowparkVersion", "-header", s"""
@@ -165,13 +178,11 @@ lazy val root = (project in file(".")) // Release settings - // Release JAR including compiled test classes - Test / packageBin / publishArtifact := true, - // Also publish a test-sources JAR - Test / packageSrc / publishArtifact := true, - Test / packageSrc / artifact := - (Compile / packageSrc / artifact).value.withClassifier(Some("tests-sources")), - addArtifact(Test / packageSrc / artifact, Test / packageSrc), + // Disable publishing the source files JAR unless publishing to maven. + Compile / packageSrc / publishArtifact := !includeFatJarsAndBundles, + + // Disable publishing test source files in all scenarios. + Test / packageSrc / publishArtifact := false, // Fat JAR settings assembly / assemblyJarName := @@ -247,6 +258,16 @@ lazy val root = (project in file(".")) Artifact(name = snowparkName, `type` = "bundle", extension = "tar.gz", classifier = "bundle"), Universal / packageZipTarball), + // Explicitly list checksum files to be generated for visibility + checksums := Seq("md5", "sha1"), + + // Filter out bundles and fat jars if publishing to maven + artifacts := artifacts.value filter ( + a => includeFatJarsAndBundles || !isFatJarOrBundle(a.classifier.getOrElse(""))), + packagedArtifacts := packagedArtifacts.value filter ( + af => includeFatJarsAndBundles || !isFatJarOrBundle(af._1.classifier.getOrElse(""))), + + // Signed publish settings credentials += Credentials(Path.userHome / ".ivy2" / ".credentials"), // Set up GPG key for release build from environment variable: GPG_HEX_CODE // Build jenkins job must have set it, otherwise, the release build will fail. @@ -256,7 +277,6 @@ lazy val root = (project in file(".")) Properties.envOrNone("GPG_HEX_CODE").getOrElse("Jenkins_build_not_set_GPG_HEX_CODE"), "ignored" // this field is ignored; passwords are supplied by pinentry ), - // usePgpKeyHex(Properties.envOrElse("GPG_SIGNATURE", "12345")), Global / pgpPassphrase := Properties.envOrNone("GPG_KEY_PASSPHRASE").map(_.toCharArray), publishMavenStyle := true, releaseCrossBuild := true, diff --git a/scripts/deploy-common.sh b/scripts/deploy-common.sh new file mode 100644 index 00000000..3670cb33 --- /dev/null +++ b/scripts/deploy-common.sh @@ -0,0 +1,130 @@ +#!/bin/bash -ex +# +# DO NOT RUN DIRECTLY. +# Script must be sourced by deploy.sh or deploy-fips.sh +# after setting or unsetting `SNOWPARK_FIPS` environment variable. +# + +if [ -z "$GPG_KEY_ID" ]; then + export GPG_KEY_ID="Snowflake Computing" + echo "[WARN] GPG key ID not specified, using default: $GPG_KEY_ID." +fi + +if [ -z "$GPG_KEY_PASSPHRASE" ]; then + echo "[ERROR] GPG passphrase is not specified for $GPG_KEY_ID!" + exit 1 +fi + +if [ -z "$GPG_PRIVATE_KEY" ]; then + echo "[ERROR] GPG private key file is not specified!" + exit 1 +fi + +if [ -z "$sonatype_user" ]; then + echo "[ERROR] Jenkins sonatype user is not specified!" + exit 1 +fi + +if [ -z "$sonatype_password" ]; then + echo "[ERROR] Jenkins sonatype pwd is not specified!" + exit 1 +fi + +if [ -z "$PUBLISH" ]; then + echo "[ERROR] 'PUBLISH' is not specified!" + exit 1 +fi + +if [ -z "$github_version_tag" ]; then + echo "[ERROR] 'github_version_tag' is not specified!" + exit 1 +fi + +mkdir -p ~/.ivy2 + +STR=$'host=central.sonatype.com +user='$sonatype_user' +password='$sonatype_password'' + +echo "$STR" > ~/.ivy2/.credentials + +# import private key first +echo "[INFO] Importing PGP key." +if [ ! -z "$GPG_PRIVATE_KEY" ] && [ -f "$GPG_PRIVATE_KEY" ]; then + # First check if already imported private key + if ! 
gpg --list-secret-key | grep "$GPG_KEY_ID"; then + gpg --allow-secret-key-import --import "$GPG_PRIVATE_KEY" + fi +fi + +which sbt +if [ $? -ne 0 ] +then + pushd .. + echo "[INFO] sbt is not installed, downloading latest sbt for test and build." + curl -L -o sbt-1.11.4.zip https://github.com/sbt/sbt/releases/download/v1.11.4/sbt-1.11.4.zip + unzip sbt-1.11.4.zip + PATH=$PWD/sbt/bin:$PATH + popd +else + echo "[INFO] Using system installed sbt." +fi +which sbt +sbt version + +echo "[INFO] Checking out snowpark-java-scala @ tag: $github_version_tag." +git checkout $github_version_tag + +if [ "$PUBLISH" = true ]; then + if [ "$SNOWPARK_FIPS" = true ]; then + echo "[INFO] Packaging snowpark-fips @ tag: $github_version_tag." + else + echo "[INFO] Packaging snowpark @ tag: $github_version_tag." + fi + sbt +publishSigned + echo "[INFO] Staged packaged artifacts locally with PGP signing." + sbt sonaUpload + echo "[SUCCESS] Uploaded artifacts to central portal." + echo "[ACTION-REQUIRED] Please log in to Central Portal to publish artifacts: https://central.sonatype.com/" + # TODO: alternatively automate publishing fully +# sbt sonaRelease +# echo "[SUCCESS] Released Snowpark Java-Scala $github_version_tag to Maven." +else + #release to s3 + echo "[INFO] Staging signed artifacts to local ivy2 repository." + rm -rf ~/.ivy2/local/ + sbt +publishLocalSigned + + # SBT will build FIPS version of Snowpark automatically if the environment variable exists. + if [ "$SNOWPARK_FIPS" = true ]; then + S3_JENKINS_URL="s3://sfc-eng-jenkins/repository/snowparkclient-fips/$github_version_tag/" + S3_DATA_URL="s3://sfc-eng-data/client/snowparkclient-fips/releases/$github_version_tag/" + echo "[INFO] Uploading snowpark-fips artifacts to:" + else + S3_JENKINS_URL="s3://sfc-eng-jenkins/repository/snowparkclient/$github_version_tag/" + S3_DATA_URL="s3://sfc-eng-data/client/snowparkclient/releases/$github_version_tag/" + echo "[INFO] Uploading snowpark artifacts to:" + fi + echo "[INFO] - $S3_JENKINS_URL" + echo "[INFO] - $S3_DATA_URL" + + # Remove release folders in s3 for current release version if they already exist due to previously failed release pipeline runs. + echo "[INFO] Deleting $github_version_tag release folders in s3 if they already exist." + aws s3 rm "$S3_JENKINS_URL" --recursive + echo "[INFO] $S3_JENKINS_URL folder deleted if it exists." + aws s3 rm "$S3_DATA_URL" --recursive + echo "[INFO] $S3_DATA_URL folder deleted if it exists." + + # Rename all produced artifacts to include version number (sbt doesn't by default when publishing to local ivy2 repository). + # TODO: BEFORE SNOWPARK v2.12.0, fix the regex in the sed command to not match the 2.12.x or 2.13.x named folder under ~/.ivy2/local/com.snowflake/snowpark_2.1[23]/ + find ~/.ivy2/local -type f -name '*snowpark*' | while read file; do newfile=$(echo "$file" | sed "s/\(2\.1[23]\)\([-\.]\)/\1-${github_version_tag#v}\2/"); mv "$file" "$newfile"; done + + # Generate sha256 checksums for all artifacts produced except .md5, .sha1, and existing .sha256 checksum files. + find ~/.ivy2/local -type f -name '*snowpark*' ! -name '*.md5' ! -name '*.sha1' ! -name '*.sha256' -exec sh -c 'for f; do sha256sum "$f" | awk '"'"'{printf "%s", $1}'"'"' > "$f.sha256"; done' _ {} + + + # Copy all files, flattening the nested structure of the ivy2 repository into the expected structure on s3. + find ~/.ivy2/local -type f -name '*snowpark*' ! -name '*.sha1' -exec aws s3 cp \{\} $S3_JENKINS_URL \; + find ~/.ivy2/local -type f -name '*snowpark*' ! 
-name '*.sha1' -exec aws s3 cp \{\} $S3_DATA_URL \; + + echo "[SUCCESS] Published Snowpark Java-Scala $github_version_tag artifacts to S3." +fi diff --git a/scripts/deploy-fips.sh b/scripts/deploy-fips.sh new file mode 100755 index 00000000..69d141fe --- /dev/null +++ b/scripts/deploy-fips.sh @@ -0,0 +1,8 @@ +#!/bin/bash -ex +# +# Push Snowpark Java/Scala FIPS build to the public maven repository. +# This script needs to be executed by snowflake jenkins job. +# + +export SNOWPARK_FIPS="true" +source scripts/deploy-common.sh diff --git a/scripts/deploy.sh b/scripts/deploy.sh index cdfa5705..e4715f12 100755 --- a/scripts/deploy.sh +++ b/scripts/deploy.sh @@ -1,121 +1,8 @@ #!/bin/bash -ex # -# Push Snowpark Java/Scala to the public maven repository. +# Push Snowpark Java/Scala build to the public maven repository. # This script needs to be executed by snowflake jenkins job. -# If the SNOWPARK_FIPS environment variable exists when running -# the script, the fips build of the snowpark client will be -# published instead of the regular build. # -if [ -z "$GPG_KEY_ID" ]; then - export GPG_KEY_ID="Snowflake Computing" - echo "[WARN] GPG key ID not specified, using default: $GPG_KEY_ID." -fi - -if [ -z "$GPG_KEY_PASSPHRASE" ]; then - echo "[ERROR] GPG passphrase is not specified for $GPG_KEY_ID!" - exit 1 -fi - -if [ -z "$GPG_PRIVATE_KEY" ]; then - echo "[ERROR] GPG private key file is not specified!" - exit 1 -fi - -if [ -z "$sonatype_user" ]; then - echo "[ERROR] Jenkins sonatype user is not specified!" - exit 1 -fi - -if [ -z "$sonatype_password" ]; then - echo "[ERROR] Jenkins sonatype pwd is not specified!" - exit 1 -fi - -if [ -z "$PUBLISH" ]; then - echo "[ERROR] 'PUBLISH' is not specified!" - exit 1 -fi - -if [ -z "$github_version_tag" ]; then - echo "[ERROR] 'github_version_tag' is not specified!" - exit 1 -fi - -mkdir -p ~/.ivy2 - -STR=$'host=central.sonatype.com -user='$sonatype_user' -password='$sonatype_password'' - -echo "$STR" > ~/.ivy2/.credentials - -# import private key first -echo "[INFO] Importing PGP key." -if [ ! -z "$GPG_PRIVATE_KEY" ] && [ -f "$GPG_PRIVATE_KEY" ]; then - # First check if already imported private key - if ! gpg --list-secret-key | grep "$GPG_KEY_ID"; then - gpg --allow-secret-key-import --import "$GPG_PRIVATE_KEY" - fi -fi - -which sbt -if [ $? -ne 0 ] -then - pushd .. - echo "[INFO] sbt is not installed, downloading latest sbt for test and build." - curl -L -o sbt-1.11.4.zip https://github.com/sbt/sbt/releases/download/v1.11.4/sbt-1.11.4.zip - unzip sbt-1.11.4.zip - PATH=$PWD/sbt/bin:$PATH - popd -else - echo "[INFO] Using system installed sbt." -fi -which sbt -sbt version - -echo "[INFO] Checking out snowpark-java-scala @ tag: $github_version_tag." -git checkout tags/$github_version_tag - -if [ "$PUBLISH" = true ]; then - if [ -v SNOWPARK_FIPS ]; then - echo "[INFO] Packaging snowpark-fips @ tag: $github_version_tag." - else - echo "[INFO] Packaging snowpark @ tag: $github_version_tag." - fi - sbt +publishSigned - echo "[INFO] Staged packaged artifacts locally with PGP signing." - sbt sonaUpload - echo "[SUCCESS] Uploaded artifacts to central portal." - echo "[ACTION-REQUIRED] Please log in to Central Portal to publish artifacts: https://central.sonatype.com/" - # TODO: alternatively automate publishing fully -# sbt sonaRelease -# echo "[SUCCESS] Released Snowpark Java-Scala v$github_version_tag to Maven." -else - #release to s3 - echo "[INFO] Staging signed artifacts to local ivy2 repository." 
- rm -rf ~/.ivy2/local/ - sbt +publishLocalSigned - - # SBT will build FIPS version of Snowpark automatically if the environment variable exists. - if [ -v SNOWPARK_FIPS ]; then - S3_JENKINS_URL="s3://sfc-eng-jenkins/repository/snowparkclient-fips" - S3_DATA_URL="s3://sfc-eng-data/client/snowparkclient-fips/releases" - echo "[INFO] Uploading snowpark-fips artifacts to:" - else - S3_JENKINS_URL="s3://sfc-eng-jenkins/repository/snowparkclient" - S3_DATA_URL="s3://sfc-eng-data/client/snowparkclient/releases" - echo "[INFO] Uploading snowpark artifacts to:" - fi - echo "[INFO] - $S3_JENKINS_URL/$github_version_tag/" - echo "[INFO] - $S3_DATA_URL/$github_version_tag/" - - # Rename all produced artifacts to include version number (sbt doesn't by default when publishing to local ivy2 repository). - find ~/.ivy2/local -type f -name "*snowpark*" | while read file; do newfile=$(echo "$file" | sed "s/\(2\.1[23]\)\([-\.]\)/\1-${github_version_tag#v}\2/"); mv "$file" "$newfile"; done - - # Copy all files, flattening the nested structure of the ivy2 repository into the expected structure on s3. - find ~/.ivy2/local -type f -name "*snowpark*" -exec aws s3 cp \{\} $S3_JENKINS_URL/$github_version_tag/ \; - find ~/.ivy2/local -type f -name "*snowpark*" -exec aws s3 cp \{\} $S3_DATA_URL/$github_version_tag/ \; - - echo "[SUCCESS] Published Snowpark Java-Scala v$github_version_tag artifacts to S3." -fi +unset SNOWPARK_FIPS +source scripts/deploy-common.sh diff --git a/scripts/utils.sh b/scripts/utils.sh index 942fde49..720775a8 100644 --- a/scripts/utils.sh +++ b/scripts/utils.sh @@ -62,14 +62,22 @@ run_test_suites() { # Avoid failures in subsequent test runs due to an already closed stderr. export DISABLE_REDIRECT_STDERR="" + # Set JVM system property for FIPS test if SNOWPARK_FIPS is true. + if [ "$SNOWPARK_FIPS" = true ]; then + FIPS='-J-DFIPS_TEST=true' + echo "Passing $FIPS to sbt" + else + FIPS='' + fi + # test - sbt clean +compile \ + sbt $FIPS clean +compile \ +JavaAPITests:test \ +NonparallelTests:test \ '++ 2.12.20 OtherTests:testOnly * -- -l SampleDataTest' \ '++ 2.13.16 OtherTests:testOnly * -- -l SampleDataTest' \ '++ 2.12.20 UDFTests:testOnly * -- -l SampleDataTest' \ - '++ 2.13.16 UDFTests:testOnly * -- -l SampleDataTest -l com.snowflake.snowpark.UDFPackageTest' \ + '++ 2.13.16 UDFTests:testOnly * -- -l SampleDataTest -l com.snowflake.snowpark.UDFPackageTest' \ '++ 2.12.20 UDTFTests:testOnly * -- -l SampleDataTest' \ '++ 2.13.16 UDTFTests:testOnly * -- -l SampleDataTest -l com.snowflake.snowpark.UDFPackageTest' \ +SprocTests:test diff --git a/src/main/java/com/snowflake/snowpark_java/CaseExpr.java b/src/main/java/com/snowflake/snowpark_java/CaseExpr.java index 875ae02b..4be9fede 100644 --- a/src/main/java/com/snowflake/snowpark_java/CaseExpr.java +++ b/src/main/java/com/snowflake/snowpark_java/CaseExpr.java @@ -20,25 +20,56 @@ public class CaseExpr extends Column { } /** - * Appends one more WHEN condition to the CASE expression. + * Appends one more WHEN condition to the CASE expression. This method handles any literal value + * and converts it into a `Column` if applies. + * + *

+   * <p>Example:

+   * <pre>{@code
+   * Column result = when(col("age").lt(lit(18)), "Minor")
+   * .when(col("age").lt(lit(65)), "Adult")
+   * .otherwise("Senior");
+   * }</pre>
    *
-   * @since 0.12.0
    * @param condition The case condition
    * @param value The result value in the given condition
    * @return The result case expression
+   * @since 0.12.0
    */
-  public CaseExpr when(Column condition, Column value) {
-    return new CaseExpr(caseExpr.when(condition.toScalaColumn(), value.toScalaColumn()));
+  public CaseExpr when(Column condition, Object value) {
+    return new CaseExpr(caseExpr.when(condition.toScalaColumn(), toExpr(value).toScalaColumn()));
   }
 
   /**
-   * Sets the default result for this CASE expression.
+   * Sets the default result for this CASE expression. This method handles any literal value and
+   * converts it into a `Column` when applicable.
+   *

+   * <p>Example:

+   * <pre>{@code
+   * Column result = when(col("state").equal(lit("CA")), lit(1000))
+   * .when(col("state").equal(lit("NY")), lit(2000))
+   * .otherwise(1000);
+   * }</pre>
* + * @param value The default value, which can be any literal (e.g., String, int, boolean) or a + * `Column`. + * @return The result column. * @since 0.12.0 - * @param value The default value - * @return The result column */ - public Column otherwise(Column value) { - return new Column(caseExpr.otherwise(value.toScalaColumn())); + public Column otherwise(Object value) { + return new Column(caseExpr.otherwise(toExpr(value).toScalaColumn())); + } + + /** + * Converts any value to an Expression. If the value is already a Column, uses its expression + * directly. Otherwise, wraps it with lit() to create a Column expression. + */ + private Column toExpr(Object exp) { + if (exp instanceof Column) { + return ((Column) exp); + } + + return Functions.lit(exp); } } diff --git a/src/main/java/com/snowflake/snowpark_java/DataFrame.java b/src/main/java/com/snowflake/snowpark_java/DataFrame.java index c4644c21..d96d5194 100644 --- a/src/main/java/com/snowflake/snowpark_java/DataFrame.java +++ b/src/main/java/com/snowflake/snowpark_java/DataFrame.java @@ -903,6 +903,18 @@ public void show() { df.show(); } + /** + * Evaluates this DataFrame and prints out the first ten rows with configurable column truncation. + * + * @param truncate Whether to truncate long column values. If {@code true}, column values longer + * than 50 characters will be truncated with "...". If {@code false}, full column values will + * be displayed regardless of length. + * @since 1.17.0 + */ + public void show(boolean truncate) { + df.show(truncate); + } + /** * Evaluates this DataFrame and prints out the first `''n''` rows. * @@ -913,6 +925,20 @@ public void show(int n) { df.show(n); } + /** + * Evaluates this DataFrame and prints out the first {@code n} rows with configurable column + * truncation. + * + * @param n The number of rows to print out. + * @param truncate Whether to truncate long column values. If {@code true}, column values longer + * than 50 characters will be truncated with "...". If {@code false}, full column values will + * be displayed regardless of length. + * @since 1.17.0 + */ + public void show(int n, boolean truncate) { + df.show(n, truncate); + } + /** * Evaluates this DataFrame and prints out the first `''n''` rows with the specified maximum * number of characters per column. diff --git a/src/main/java/com/snowflake/snowpark_java/Functions.java b/src/main/java/com/snowflake/snowpark_java/Functions.java index 32259e2a..cfc88549 100644 --- a/src/main/java/com/snowflake/snowpark_java/Functions.java +++ b/src/main/java/com/snowflake/snowpark_java/Functions.java @@ -1482,6 +1482,51 @@ public static Column concat_ws(Column separator, Column... exprs) { JavaUtils.columnArrayToSeq(Column.toScalaColumnArray(exprs)))); } + /** + * Concatenates two or more strings ignoring any null values. + * + *

+   * <p>Unlike {@link #concat_ws}, this function automatically filters out null values before
+   * concatenation.

+   * <p>Examples

+   * <pre>{@code
+   * DataFrame df = session.createDataFrame(
+   *   new Row[] {
+   *     Row.create("Hello", "World", null),
+   *     Row.create(null, null, null),
+   *     Row.create("Hello", null, null)
+   *   },
+   *   StructType.create(
+   *     new StructField("A", DataTypes.StringType),
+   *     new StructField("B", DataTypes.StringType),
+   *     new StructField("C", DataTypes.StringType)
+   *   )
+   * );
+   *
+   * df.select(
+   *   concat_ws_ignore_nulls(" | ", col("A"), col("B"), col("C")).as("concat_ws_ignore_nulls")
+   * ).show();
+   * ----------------------------
+   * |"CONCAT_WS_IGNORE_NULLS"  |
+   * ----------------------------
+   * |Hello | World             |
+   * |                          |
+   * |Hello                     |
+   * ----------------------------
+   * }</pre>
+ * + * @param separator A string literal used as the separator between concatenated values. + * @param exprs The columns to be concatenated. + * @return A Column containing the concatenated values with null values filtered out. + * @since 1.17.0 + */ + public static Column concat_ws_ignore_nulls(String separator, Column... exprs) { + return new Column( + com.snowflake.snowpark.functions.concat_ws_ignore_nulls( + separator, JavaUtils.columnArrayToSeq(Column.toScalaColumnArray(exprs)))); + } + /** * Returns the input string with the first letter of each word in uppercase and the subsequent * letters in lowercase. @@ -3579,6 +3624,95 @@ public static Column array_contains(Column variant, Column array) { variant.toScalaColumn(), array.toScalaColumn())); } + /** + * Flattens an array of arrays into a single array, removing only one level of nesting. + * + *

+   * <p>Examples

+   * <p>Example 1: Flattens a two-level nested array into a single array of elements.

+   * <pre>{@code
+   * DataFrame df = session.createDataFrame(
+   *   new Row[] {
+   *     Row.create((Object) new int[][] {{1, 2}, {3, 4}}),
+   *     Row.create((Object) new int[][] {{5, 6, 7}, {8}}),
+   *     Row.create((Object) new int[][] {{}, {9, 10}}),
+   *   },
+   *   StructType.create(new StructField(
+   *     "a",
+   *     DataTypes.createArrayType(DataTypes.createArrayType(DataTypes.IntegerType))
+   *   ))
+   * );
+   *
+   * df.select(array_flatten(col("a"))).show();
+   * --------------------------
+   * |"ARRAY_FLATTEN(""A"")"  |
+   * --------------------------
+   * |[                       |
+   * |  1,                    |
+   * |  2,                    |
+   * |  3,                    |
+   * |  4                     |
+   * |]                       |
+   * |[                       |
+   * |  5,                    |
+   * |  6,                    |
+   * |  7,                    |
+   * |  8                     |
+   * |]                       |
+   * |[                       |
+   * |  9,                    |
+   * |  10                    |
+   * |]                       |
+   * --------------------------
+   * }</pre>
+ * + *

+   * <p>Example 2: Flattens only one level of a three-level nested array.

+   * <pre>{@code
+   * DataFrame df = session.createDataFrame(
+   *   new Row[] {
+   *     Row.create((Object) new int[][][] {{{1, 2}, {3}}, {{4, 5}}}),
+   *   },
+   *   StructType.create(new StructField(
+   *     "a",
+   *     DataTypes.createArrayType(DataTypes.createArrayType(DataTypes.createArrayType(DataTypes.IntegerType)))
+   *   ))
+   * );
+   *
+   * df.select(array_flatten(col("a"))).show();
+   * --------------------------
+   * |"ARRAY_FLATTEN(""A"")"  |
+   * --------------------------
+   * |[                       |
+   * |  [                     |
+   * |    1,                  |
+   * |    2                   |
+   * |  ],                    |
+   * |  [                     |
+   * |    3                   |
+   * |  ],                    |
+   * |  [                     |
+   * |    4,                  |
+   * |    5                   |
+   * |  ]                     |
+   * |]                       |
+   * --------------------------
+   * }</pre>
+ * + * @param array Column containing the array of arrays to flatten. + * + * + * @return A column containing the flattened array. + * @since 1.17.0 + */ + public static Column array_flatten(Column array) { + return new Column(com.snowflake.snowpark.functions.array_flatten(array.toScalaColumn())); + } + /** * Returns an ARRAY containing all elements from the source ARRAY as well as the new element. * @@ -4191,19 +4325,20 @@ public static Column get_path(Column col, Column path) { *
    * <pre>{@code
    * import com.snowflake.snowpark_java.Functions;
    * df.select(Functions
-   *     .when(df.col("col").is_null, Functions.lit(1))
-   *     .when(df.col("col").equal_to(Functions.lit(1)), Functions.lit(6))
+   *     .when(df.col("col").is_null, 1)
+   *     .when(df.col("col").equal_to(Functions.lit(1)), 6)
    *     .otherwise(Functions.lit(7)));
    * }</pre>
* - * @since 0.12.0 * @param condition The condition * @param value The result value * @return The result column + * @since 0.12.0 */ - public static CaseExpr when(Column condition, Column value) { + public static CaseExpr when(Column condition, Object value) { return new CaseExpr( - com.snowflake.snowpark.functions.when(condition.toScalaColumn(), value.toScalaColumn())); + com.snowflake.snowpark.functions.when( + condition.toScalaColumn(), toExpr(value).toScalaColumn())); } /** @@ -5904,4 +6039,16 @@ private static UserDefinedFunction userDefinedFunction( String funcName, Supplier func) { return javaUDF("Functions", funcName, "", "", func); } + + /** + * Converts any value to an Expression. If the value is already a Column, uses its expression + * directly. Otherwise, wraps it with lit() to create a Column expression. + */ + private static Column toExpr(Object exp) { + if (exp instanceof Column) { + return ((Column) exp); + } + + return Functions.lit(exp); + } } diff --git a/src/main/scala/com/snowflake/snowpark/Column.scala b/src/main/scala/com/snowflake/snowpark/Column.scala index 64434330..2fe81efb 100644 --- a/src/main/scala/com/snowflake/snowpark/Column.scala +++ b/src/main/scala/com/snowflake/snowpark/Column.scala @@ -764,18 +764,54 @@ class CaseExpr private[snowpark] (branches: Seq[(Expression, Expression)]) /** * Appends one more WHEN condition to the CASE expression. * + * This method handles any literal value and converts it into a `Column`. + * + * ===Example=== + * {{{ + * val df = session.sql("SELECT * FROM values (10), (25), (65), (70) as T(age)") + * val result = df.select( + * when(col("age") < lit(18), "Minor") + * .when(col("age") < lit(65), lit("Adult")) + * .otherwise("Senior") + * ) + * // The second when condition will be "Adult" for rows where age >= 18 and age < 65 + * }}} + * + * @param condition + * The case condition. + * @param value + * The result value, which can be any literal (e.g., String, Int, Boolean) or a `Column`. + * @return + * The result case expression. * @since 0.2.0 */ - def when(condition: Column, value: Column): CaseExpr = - new CaseExpr(branches :+ ((condition.expr, value.expr))) + def when(condition: Column, value: Any): CaseExpr = + new CaseExpr(branches :+ (condition.expr, toExpr(value))) /** * Sets the default result for this CASE expression. * + * This method handles any literal value and converts it into a `Column` using `lit()`. + * + * ===Example=== + * {{{ + * val df = session.sql("SELECT * FROM values (10), (25), (65), (70) as T(age)") + * val result = df.select( + * when(col("age") < lit(18), "Minor") + * .when(col("age") < lit(65), lit("Adult")) + * .otherwise("Senior") + * ) + * // The age_category column will be "Senior" for rows where age >= 65 + * }}} + * + * @param value + * The default value, which can be any literal (e.g., String, Int, Boolean) or a `Column`. + * @return + * The result column. * @since 0.2.0 */ - def otherwise(value: Column): Column = withExpr { - CaseWhen(branches, Option(value.expr)) + def otherwise(value: Any): Column = withExpr { + CaseWhen(branches, Option(toExpr(value))) } /** @@ -783,5 +819,14 @@ class CaseExpr private[snowpark] (branches: Seq[(Expression, Expression)]) * * @since 0.2.0 */ - def `else`(value: Column): Column = otherwise(value) + def `else`(value: Any): Column = otherwise(value) + + /** + * Converts any value to an Expression. If the value is already a Column, uses its expression + * directly. Otherwise, wraps it with lit() to create a Column expression. 
+ */ + private def toExpr(exp: Any) = exp match { + case c: Column => c.expr + case _ => lit(exp).expr + } } diff --git a/src/main/scala/com/snowflake/snowpark/DataFrame.scala b/src/main/scala/com/snowflake/snowpark/DataFrame.scala index 1e9ad03a..0c381f5c 100644 --- a/src/main/scala/com/snowflake/snowpark/DataFrame.scala +++ b/src/main/scala/com/snowflake/snowpark/DataFrame.scala @@ -2616,6 +2616,35 @@ class DataFrame private[snowpark] ( show(n, 50) } + /** + * Evaluates this DataFrame and prints out the first `n` rows. + * + * @param n + * The number of rows to print out. + * @param truncate + * Whether to truncate long column values. If `true`, column values longer than 50 characters + * will be truncated. If `false`, full column values will be displayed. + * @group actions + * @since 1.17.0 + */ + def show(n: Int, truncate: Boolean): Unit = action("show") { + val maxWidth = if (truncate) 50 else 0 + this.show(n, maxWidth) + } + + /** + * Evaluates this DataFrame and prints out the first 10 rows. + * + * @param truncate + * Whether to truncate long column values. If `true`, column values longer than 50 characters + * will be truncated. If `false`, full column values will be displayed. + * @group actions + * @since 1.17.0 + */ + def show(truncate: Boolean): Unit = action("show") { + this.show(10, truncate) + } + /** * Evaluates this DataFrame and prints out the first `''n''` rows with the specified maximum * number of characters per column. @@ -2714,7 +2743,7 @@ class DataFrame private[snowpark] ( if (colWidth(index) < str.length) { colWidth(index) = str.length } - if (colWidth(index) > maxWidth) { + if (maxWidth != 0 && colWidth(index) > maxWidth) { colWidth(index) = maxWidth } }) @@ -2746,7 +2775,7 @@ class DataFrame private[snowpark] ( row .zip(colWidth) .map { case (str, size) => - if (str.length > maxWidth) { + if (maxWidth != 0 && str.length > maxWidth) { // if truncated, add ... to the end (str.take(maxWidth - 3) + "...").padTo(size, " ").mkString } else { diff --git a/src/main/scala/com/snowflake/snowpark/functions.scala b/src/main/scala/com/snowflake/snowpark/functions.scala index f1994fa7..8acf20eb 100644 --- a/src/main/scala/com/snowflake/snowpark/functions.scala +++ b/src/main/scala/com/snowflake/snowpark/functions.scala @@ -1307,6 +1307,54 @@ object functions { builtin("concat_ws")(args: _*) } + /** + * Concatenates two or more strings ignoring any null values. + * + * Unlike [[concat_ws]], this function automatically filters out null values before concatenation. + * + * '''Examples''' + * + * {{{ + * val df = session.createDataFrame( + * Seq( + * Row("Hello", "World", null), + * Row(null, null, null), + * Row("Hello", null, null), + * ), + * StructType( + * StructField("A", StringType), + * StructField("B", StringType), + * StructField("C", StringType), + * ) + * ) + * + * df.select( + * concat_ws_ignore_nulls(" | ", col("A"), col("B"), col("C")).as("concat_ws_ignore_nulls") + * ).show() + * ---------------------------- + * |"CONCAT_WS_IGNORE_NULLS" | + * ---------------------------- + * |Hello | World | + * | | + * |Hello | + * ---------------------------- + * }}} + * + * @param separator + * A string literal used as the separator between concatenated values. + * @param exprs + * The columns to be concatenated. + * @return + * A Column containing the concatenated values with null values filtered out. 
+ * @group str_func + * @since 1.17.0 + */ + def concat_ws_ignore_nulls(separator: String, exprs: Column*): Column = { + val stringArrays = exprs.map(_.cast(ArrayType(StringType))) + val nonNullArray = array_compact(array_flatten(array_construct_compact(stringArrays: _*))) + array_to_string(nonNullArray, lit(separator)) + } + /** * Returns the input string with the first letter of each word in uppercase and the subsequent * letters in lowercase. @@ -2954,6 +3002,81 @@ object functions { builtin("array_contains")(variant, array) } + /** + * Flattens an array of arrays into a single array, removing only one level of nesting. + * + * '''Examples''' + * + * Example 1: Flattens a two-level nested array into a single array of elements. + * + * {{{ + * val df = Seq( + * Array(Array(1, 2), Array(3, 4)), + * Array(Array(5, 6, 7), Array(8)), + * Array(Array.empty[Int], Array(9, 10)), + * ).toDF("a") + * + * df.select(array_flatten(col("a"))).show() + * -------------------------- + * |"ARRAY_FLATTEN(""A"")" | + * -------------------------- + * |[ | + * | 1, | + * | 2, | + * | 3, | + * | 4 | + * |] | + * |[ | + * | 5, | + * | 6, | + * | 7, | + * | 8 | + * |] | + * |[ | + * | 9, | + * | 10 | + * |] | + * -------------------------- + * }}} + * + * Example 2: Flattens only one level of a three-level nested array. + * + * {{{ + * val df = Seq( + * Array(Array(Array(1, 2), Array(3)), Array(Array(4, 5))) + * ).toDF("a") + * + * df.select(array_flatten(col("a"))).show() + * -------------------------- + * |"ARRAY_FLATTEN(""A"")" | + * -------------------------- + * |[ | + * | [ | + * | 1, | + * | 2 | + * | ], | + * | [ | + * | 3 | + * | ], | + * | [ | + * | 4, | + * | 5 | + * | ] | + * |] | + * -------------------------- + * }}} + * + * @param array + * Column containing the array of arrays to flatten. + * - If any element of `array` is not an ARRAY, the function throws an error. + * - If `array` is NULL, the function returns NULL. + * @return + * A column containing the flattened array. + * @group semi_func + * @since 1.17.0 + */ + def array_flatten(array: Column): Column = builtin("array_flatten")(array) + /** * Returns an ARRAY containing all elements from the source ARRAY as well as the new element. * @@ -3471,21 +3594,31 @@ object functions { * Works like a cascading if-then-else statement. A series of conditions are evaluated in * sequence. When a condition evaluates to TRUE, the evaluation stops and the associated result * (after THEN) is returned. If none of the conditions evaluate to TRUE, then the result after the - * optional OTHERWISE is returned, if present; otherwise NULL is returned. For Example: + * optional OTHERWISE is returned, if present; otherwise NULL is returned. + * + * ===Example=== * {{{ - * import functions._ - * df.select( - * when(col("col").is_null, lit(1)) - * .when(col("col") === 1, lit(2)) - * .otherwise(lit(3)) - * ) + * import functions._ + * val df = session.sql("SELECT * FROM values (null, 5), (1, 10), (2, 15) as T(col, numeric_col)") + * val result = df.select( + * when(col("col").is_null, lit(1)) + * .when(col("col") === 1, lit(2)) + * .when(col("col") === 1, col("numeric_col") * 0.10) + * .otherwise(lit(3)) + * ) * }}} * + * @param condition + * The case condition. + * @param value + * The result value, which can be any literal (e.g., String, Int, Boolean) or a `Column`. + * @return + * The result case expression. 
* @group con_func * @since 0.2.0 */ - def when(condition: Column, value: Column): CaseExpr = - new CaseExpr(Seq((condition.expr, value.expr))) + def when(condition: Column, value: Any): CaseExpr = + new CaseExpr(Seq((condition.expr, toExpr(value)))) /** * Returns one of two specified expressions, depending on a condition. @@ -5549,4 +5682,12 @@ object functions { "")(func) } + /** + * Converts any value to an Expression. If the value is already a Column, uses its expression + * directly. Otherwise, wraps it with lit() to create a Column expression. + */ + private def toExpr(exp: Any) = exp match { + case c: Column => c.expr + case _ => lit(exp).expr + } } diff --git a/src/test/java/com/snowflake/snowpark_test/JavaColumnSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaColumnSuite.java index 3aa04cfe..da7bf5ad 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaColumnSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaColumnSuite.java @@ -281,6 +281,30 @@ public void caseWhen() { Row.create((Object) null), Row.create(5) }); + + // Handling no column type values + checkAnswer( + df.select( + Functions.when(df.col("a").is_null(), 5) + .when(df.col("a").equal_to(Functions.lit(1)), 6) + .otherwise(7) + .as("a")), + new Row[] {Row.create(5), Row.create(7), Row.create(6), Row.create(7), Row.create(5)}); + + // Handling null values + checkAnswer( + df.select( + Functions.when(df.col("a").is_null(), null) + .when(df.col("a").equal_to(Functions.lit(1)), null) + .otherwise(null) + .as("a")), + new Row[] { + Row.create((Object) null), + Row.create((Object) null), + Row.create((Object) null), + Row.create((Object) null), + Row.create((Object) null) + }); } @Test diff --git a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java index 72e477b5..71c2c86e 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java @@ -1,9 +1,17 @@ package com.snowflake.snowpark_test; +import static org.junit.Assert.assertThrows; + import com.snowflake.snowpark_java.*; +import com.snowflake.snowpark_java.types.DataTypes; +import com.snowflake.snowpark_java.types.StructField; +import com.snowflake.snowpark_java.types.StructType; import java.sql.Date; import java.sql.Time; import java.sql.Timestamp; +import java.util.Arrays; +import net.snowflake.client.jdbc.SnowflakeSQLException; +import org.junit.Assert; import org.junit.Test; public class JavaFunctionSuite extends TestBase { @@ -828,6 +836,84 @@ public void concat_ws() { df.select(Functions.concat_ws(Functions.lit(","), df.col("a"), df.col("b"))), expected); } + @Test + public void concat_ws_ignore_nulls() { + DataFrame df = + getSession() + .createDataFrame( + new Row[] { + Row.create(new String[] {"a", "b"}, new String[] {"c"}, "d", "e", 1, 2), + Row.create( + new String[] {"Hello", null, "world"}, + new String[] {null, "!", null}, + "bye", + "world", + 3, + null), + Row.create(new String[] {null, null}, new String[] {"R", "H"}, null, "TD", 4, 5), + Row.create(null, new String[] {null}, null, null, null, null), + Row.create(null, null, null, null, null, null), + }, + StructType.create( + new StructField("arr1", DataTypes.createArrayType(DataTypes.StringType)), + new StructField("arr2", DataTypes.createArrayType(DataTypes.StringType)), + new StructField("str1", DataTypes.StringType), + new StructField("str2", DataTypes.StringType), + new StructField("int1", DataTypes.IntegerType), 
+ new StructField("int2", DataTypes.IntegerType))); + + Column[] columns = + Arrays.stream(df.schema().fieldNames()).map(Functions::col).toArray(Column[]::new); + + // Single character delimiter + checkAnswer( + df.select(Functions.concat_ws_ignore_nulls(",", columns)), + new Row[] { + Row.create("a,b,c,d,e,1,2"), + Row.create("Hello,world,!,bye,world,3"), + Row.create("R,H,TD,4,5"), + Row.create(""), + Row.create("") + }); + + // Multi-character delimiter + checkAnswer( + df.select(Functions.concat_ws_ignore_nulls(" : ", columns)), + new Row[] { + Row.create("a : b : c : d : e : 1 : 2"), + Row.create("Hello : world : ! : bye : world : 3"), + Row.create("R : H : TD : 4 : 5"), + Row.create(""), + Row.create("") + }); + + DataFrame df2 = + getSession() + .createDataFrame( + new Row[] { + Row.create(Date.valueOf("2021-12-21")), Row.create(Date.valueOf("1969-12-31")) + }, + StructType.create(new StructField("YearMonth", DataTypes.DateType))); + + checkAnswer( + df2.select( + Functions.concat_ws_ignore_nulls( + "-", + Functions.year(Functions.col("YearMonth")), + Functions.month(Functions.col("YearMonth")))), + new Row[] {Row.create("2021-12"), Row.create("1969-12")}); + + // Resulting column should allow to define an alias + checkAnswer( + df2.select( + Functions.concat_ws_ignore_nulls( + "-", + Functions.year(Functions.col("YearMonth")), + Functions.month(Functions.col("YearMonth"))) + .alias("YEAR_MONTH")), + new Row[] {Row.create("2021-12"), Row.create("1969-12")}); + } + @Test public void initcap_length_lower_upper() { DataFrame df = getSession().sql("select * from values('asdFg'),('qqq'),('Qw') as T(a)"); @@ -2446,6 +2532,36 @@ public void array_slice() { df.select(Functions.array_slice(df.col("arr1"), df.col("d"), df.col("e"))), expected); } + @Test + public void array_flatten() { + // Flattening a 2D array + DataFrame df1 = + getSession().sql("SELECT [[1, 2, 3], [], [4], [5, NULL, PARSE_JSON('null')]] AS A"); + checkAnswer( + df1.select(Functions.array_flatten(Functions.col("A"))), + new Row[] {Row.create("[\n 1,\n 2,\n 3,\n 4,\n 5,\n undefined,\n null\n]")}); + + // Flattening a 3D array + DataFrame df2 = getSession().sql("SELECT [[[1, 2], [3]]] AS A"); + checkAnswer( + df2.select(Functions.array_flatten(Functions.col("A"))), + new Row[] {Row.create("[\n [\n 1,\n 2\n ],\n [\n 3\n ]\n]")}); + + // Flattening a null array + DataFrame df3 = getSession().sql("SELECT NULL::ARRAY AS A"); + checkAnswer( + df3.select(Functions.array_flatten(Functions.col("A"))), + new Row[] {Row.create((Object) null)}); + + // Flattening an array with non-array elements + DataFrame df4 = getSession().sql("SELECT [1, 2, 3] AS A"); + SnowflakeSQLException exception = + assertThrows( + SnowflakeSQLException.class, + () -> df4.select(Functions.array_flatten(Functions.col("A"))).collect()); + Assert.assertTrue(exception.getMessage().contains("not an array")); + } + @Test public void array_to_string() { DataFrame df = diff --git a/src/test/java/com/snowflake/snowpark_test/JavaOpenTelemetrySuite.java b/src/test/java/com/snowflake/snowpark_test/JavaOpenTelemetrySuite.java index 0acefee3..8d84c314 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaOpenTelemetrySuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaOpenTelemetrySuite.java @@ -47,6 +47,14 @@ public void show() { checkSpan("snow.snowpark.DataFrame", "show"); df.show(1, 100); checkSpan("snow.snowpark.DataFrame", "show"); + df.show(true); + checkSpan("snow.snowpark.DataFrame", "show"); + df.show(false); + checkSpan("snow.snowpark.DataFrame", 
"show"); + df.show(1, true); + checkSpan("snow.snowpark.DataFrame", "show"); + df.show(1, false); + checkSpan("snow.snowpark.DataFrame", "show"); } @Test diff --git a/src/test/scala/com/snowflake/snowpark/UtilsSuite.scala b/src/test/scala/com/snowflake/snowpark/UtilsSuite.scala index 44027c6a..abb091f9 100644 --- a/src/test/scala/com/snowflake/snowpark/UtilsSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/UtilsSuite.scala @@ -492,7 +492,7 @@ class UtilsSuite extends SNTestBase { } test("Utils.version matches sbt build") { - assert(Utils.Version == "1.17.0-SNAPSHOT") + assert(Utils.Version == "1.17.0") } test("Utils.retrySleepTimeInMS") { diff --git a/src/test/scala/com/snowflake/snowpark_test/ColumnSuite.scala b/src/test/scala/com/snowflake/snowpark_test/ColumnSuite.scala index 63034dde..2e7bc2cf 100644 --- a/src/test/scala/com/snowflake/snowpark_test/ColumnSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/ColumnSuite.scala @@ -513,6 +513,51 @@ class ColumnSuite extends TestData { .as("a")), Seq(Row(5), Row(7), Row(6), Row(7), Row(5))) + // no column typed value + checkAnswer( + nullData1.select( + functions + .when(col("a").is_null, lit(1)) + .when(col("a") === 1, col("a") / 2) + .when(col("a") === 2, col("a") * 2) + .when(col("a") === 3, pow(col("a"), 2)) + .as("a")), + Seq(Row(0.5), Row(1.0), Row(1.0), Row(4.0), Row(9.0))) + + checkAnswer( + nullData1.select( + functions + .when(col("a").is_null, "null_value") + .when(col("a") <= 2, "lower or equal than two") + .when(col("a") >= 3, "greater than two") + .as("a")), + Seq( + Row("greater than two"), + Row("lower or equal than two"), + Row("lower or equal than two"), + Row("null_value"), + Row("null_value"))) + + // No column otherwise + checkAnswer( + nullData1.select( + functions + .when(col("a").is_null, lit(5)) + .when(col("a") === 1, lit(6)) + .otherwise(7) + .as("a")), + Seq(Row(5), Row(7), Row(6), Row(7), Row(5))) + + // Handling nulls + checkAnswer( + nullData1.select( + functions + .when(col("a").is_null, null) + .when(col("a") === 1, null) + .otherwise(null) + .as("a")), + Seq(Row(null), Row(null), Row(null), Row(null), Row(null))) + // empty otherwise checkAnswer( nullData1.select( diff --git a/src/test/scala/com/snowflake/snowpark_test/DataFrameSuite.scala b/src/test/scala/com/snowflake/snowpark_test/DataFrameSuite.scala index 1a930903..20908b33 100644 --- a/src/test/scala/com/snowflake/snowpark_test/DataFrameSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/DataFrameSuite.scala @@ -83,6 +83,76 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { |""".stripMargin) } + test("show not truncated") { + // run show function, make sure no error reported + val df = Seq( + "Short Sample", + "Exceeding Maximum Characters Length Row Value To Evaluate Truncated Results").toDF("Column") + df.show(false) + + assert( + getShowString(df, 10, 0) == + """------------------------------------------------------------------------------- + ||"COLUMN" | + |------------------------------------------------------------------------------- + ||Short Sample | + ||Exceeding Maximum Characters Length Row Value To Evaluate Truncated Results | + |------------------------------------------------------------------------------- + |""".stripMargin) + } + + test("show truncated") { + // run show function, make sure no error reported + val df = Seq( + "Short Sample", + "Exceeding Maximum Characters Length Row Value To Evaluate Truncated Results").toDF("Column") + df.show(true) + + assert( + getShowString(df, 10) == + 
"""------------------------------------------------------ + ||"COLUMN" | + |------------------------------------------------------ + ||Short Sample | + ||Exceeding Maximum Characters Length Row Value T... | + |------------------------------------------------------ + |""".stripMargin) + } + + test("show not truncated limited rows") { + // run show function, make sure no error reported + val df = Seq( + "Exceeding Maximum Characters Length Row Value To Evaluate Truncated Results", + "Short Sample").toDF("Column") + df.show(1, false) + + assert( + getShowString(df, 1, 0) == + """------------------------------------------------------------------------------- + ||"COLUMN" | + |------------------------------------------------------------------------------- + ||Exceeding Maximum Characters Length Row Value To Evaluate Truncated Results | + |------------------------------------------------------------------------------- + |""".stripMargin) + } + + test("show truncated limited rows") { + // run show function, make sure no error reported + val df = Seq( + "Exceeding Maximum Characters Length Row Value To Evaluate Truncated Results", + "Short Sample").toDF("Column") + df.show(1, true) + + assert( + getShowString(df, 1) == + """------------------------------------------------------ + ||"COLUMN" | + |------------------------------------------------------ + ||Exceeding Maximum Characters Length Row Value T... | + |------------------------------------------------------ + |""".stripMargin) + } + test("show with null data") { // run show function, make sure no error reported val df = Seq((1, null), (2, "NotNull")).toDF("a", "b") diff --git a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala index cac76914..4e8e31ea 100644 --- a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala @@ -413,6 +413,60 @@ trait FunctionSuite extends TestData { Seq(Row("test1,a"), Row("test2,b"), Row("test3,c"))) } + test("concat_ws_ignore_nulls") { + val df = session.createDataFrame( + Seq( + Row(Array("a", "b"), Array("c"), "d", "e", 1, 2), + Row(Array("Hello", null, "world"), Array(null, "!", null), "bye", "world", 3, null), + Row(Array(null, null), Array("R", "H"), null, "TD", 4, 5), + Row(null, Array(null), null, null, null, null), + Row(null, null, null, null, null, null)), + StructType( + StructField("arr1", ArrayType(StringType)), + StructField("arr2", ArrayType(StringType)), + StructField("str1", StringType), + StructField("str2", StringType), + StructField("int1", IntegerType), + StructField("int2", IntegerType))) + + val columns = df.schema.map(field => col(field.name)) + + // Single character delimiter + checkAnswer( + df.select(concat_ws_ignore_nulls(",", columns: _*)), + Seq( + Row("a,b,c,d,e,1,2"), + Row("Hello,world,!,bye,world,3"), + Row("R,H,TD,4,5"), + Row(""), + Row(""))) + + // Multi-character delimiter + checkAnswer( + df.select(concat_ws_ignore_nulls(" : ", columns: _*)), + Seq( + Row("a : b : c : d : e : 1 : 2"), + Row("Hello : world : ! 
: bye : world : 3"), + Row("R : H : TD : 4 : 5"), + Row(""), + Row(""))) + + val df2 = session.createDataFrame( + Seq(Row(Date.valueOf("2021-12-21")), Row(Date.valueOf("1969-12-31"))), + StructType(StructField("YearMonth", DateType))) + + checkAnswer( + df2.select(concat_ws_ignore_nulls("-", year(col("YearMonth")), month(col("YearMonth")))), + Seq(Row("2021-12"), Row("1969-12"))) + + // Resulting column should allow to define an alias + checkAnswer( + df2.select( + concat_ws_ignore_nulls("-", year(col("YearMonth")), month(col("YearMonth"))) + .alias("YEAR_MONTH")), + Seq(Row("2021-12"), Row("1969-12"))) + } + test("initcap length lower upper") { checkAnswer( string2.select(initcap(col("A")), length(col("A")), lower(col("A")), upper(col("A"))), @@ -1579,6 +1633,31 @@ trait FunctionSuite extends TestData { Seq(Row("[\n 2\n]"), Row("[\n 5\n]"), Row("[\n 6,\n 7\n]"))) } + test("array_flatten") { + // Flattening a 2D array + val df1 = session.sql("SELECT [[1, 2, 3], [], [4], [5, NULL, PARSE_JSON('null')]] AS A") + checkAnswer( + df1.select(array_flatten(col("A"))), + Seq(Row("[\n 1,\n 2,\n 3,\n 4,\n 5,\n undefined,\n null\n]"))) + + // Flattening a 3D array + val df2 = session.sql("SELECT [[[1, 2], [3]]] AS A") + checkAnswer( + df2.select(array_flatten(col("A"))), + Seq(Row("[\n [\n 1,\n 2\n ],\n [\n 3\n ]\n]"))) + + // Flattening a null array + val df3 = session.sql("SELECT NULL::ARRAY AS A") + checkAnswer(df3.select(array_flatten(col("A"))), Seq(Row(null))) + + // Flattening an array with non-array elements + val df4 = session.sql("SELECT [1, 2, 3] AS A") + val exception = intercept[SnowflakeSQLException] { + df4.select(array_flatten(col("A"))).collect() + } + assert(exception.getMessage.contains("not an array")) + } + test("array_to_string") { checkAnswer( array3.select(array_to_string(col("arr1"), col("f"))),
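// A minimal end-to-end sketch of the API surface added in this change — illustrative
// only, not part of the patch. It assumes an active `session`; the DataFrame and the
// column names ("a", "b", "nested") are hypothetical.
//
//   val df = session.sql("SELECT * FROM values ('CA', null), ('NY', 'x') AS T(a, b)")
//
//   df.select(concat_ws_ignore_nulls("-", col("a"), col("b")))  // "CA" and "NY-x": nulls are dropped
//   df.select(when(col("a") === lit("CA"), 1).otherwise(0))     // when/otherwise now accept plain literals
//   df.select(array_flatten(col("nested")))                     // removes exactly one level of nesting
//   df.show(false)                                              // prints rows without 50-character truncation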