This is a classic spam-classification example using Spark MLlib's naive Bayes. You can fetch the SMS Spam Collection dataset from the UCI Machine Learning Repository.

The whole approach is straightforward and very much like how one would tackle it with scikit-learn.

As usual, let’s start with initializing Spark.

import findspark, os
from pyspark.sql import SparkSession
from pyspark import SparkContext

os.environ["JAVA_HOME"]="/Library/Java/JavaVirtualMachines/jdk1.8.0_202.jdk/Contents/Home"

print(findspark.find())
findspark.init()

sc = SparkContext.getOrCreate()
spark = SparkSession.builder.appName('Spam').getOrCreate()

Read the data and print the schema

df = spark.read.csv('SMSSpamCollection.csv', inferSchema = True, sep = '\t')
df.printSchema()

which gives

root
 |-- _c0: string (nullable = true)
 |-- _c1: string (nullable = true)

The CSV does not have a header, hence the ugly names. Let’s change this:

df = df.withColumnRenamed('_c0', 'class').withColumnRenamed('_c1', 'text')
df.show()

outputting the first 20 rows

+-----+--------------------+
|class|                text|
+-----+--------------------+
|  ham|Go until jurong p...|
|  ham|Ok lar... Joking ...|
| spam|Free entry in 2 a...|
|  ham|U dun say so earl...|
|  ham|Nah I don't think...|
| spam|FreeMsg Hey there...|
|  ham|Even my brother i...|
|  ham|As per your reque...|
| spam|WINNER!! As a val...|
| spam|Had your mobile 1...|
|  ham|I'm gonna be home...|
| spam|SIX chances to wi...|
| spam|URGENT! You have ...|
|  ham|I've been searchi...|
|  ham|I HAVE A DATE ON ...|
| spam|XXXMobileMovieClu...|
|  ham|Oh k...i'm watchi...|
|  ham|Eh u remember how...|
|  ham|Fine if that’s th...|
| spam|England v Macedon...|
+-----+--------------------+
only showing top 20 rows

Looking at the length of the text, we can see that spam is on average longer than ham:

from pyspark.sql.functions import length

df = df.withColumn('length', length(df['text']))
df.show(3)
df.groupBy('class').mean().show()

resulting in

+-----+--------------------+------+
|class|                text|length|
+-----+--------------------+------+
|  ham|Go until jurong p...|   111|
|  ham|Ok lar... Joking ...|    29|
| spam|Free entry in 2 a...|   155|
+-----+--------------------+------+
only showing top 3 rows

+-----+-----------------+
|class|      avg(length)|
+-----+-----------------+
|  ham|71.45431945307645|
| spam|138.6706827309237|
+-----+-----------------+

So, we’ll take the length into account when setting up the learning pipeline.
First, we need to clean the text a bit:

from pyspark.ml.feature import CountVectorizer, Tokenizer, StopWordsRemover, IDF, StringIndexer
tokenizer = Tokenizer(inputCol = 'text', outputCol = 'tokens')
stop_remove = StopWordsRemover(inputCol = 'tokens', outputCol = 'stop_token')
count_vec = CountVectorizer(inputCol = 'stop_token', outputCol = 'c_vec')
idf = IDF(inputCol = 'c_vec', outputCol = 'tf_idf')
ham_spam_to_numeric = StringIndexer(inputCol = 'class', outputCol = 'label')

The VectorAssembler is how you hand the features over to an ML algorithm:

from pyspark.ml.feature import VectorAssembler
clean_up = VectorAssembler(inputCols = ['tf_idf', 'length'], outputCol = 'features')

Putting the stages into a Pipeline then gives us the frame with the transformed fields:

from pyspark.ml import Pipeline
pipeline = Pipeline(stages=[ham_spam_to_numeric, tokenizer, stop_remove, count_vec, idf, clean_up])
cleaner = pipeline.fit(df)
clean_df = cleaner.transform(df)
clean_df.show(3)

showing us

+-----+--------------------+------+-----+--------------------+--------------------+--------------------+--------------------+--------------------+
|class|                text|length|label|              tokens|          stop_token|               c_vec|              tf_idf|            features|
+-----+--------------------+------+-----+--------------------+--------------------+--------------------+--------------------+--------------------+
|  ham|Go until jurong p...|   111|  0.0|[go, until, juron...|[go, jurong, poin...|(13423,[7,11,31,6...|(13423,[7,11,31,6...|(13424,[7,11,31,6...|
|  ham|Ok lar... Joking ...|    29|  0.0|[ok, lar..., joki...|[ok, lar..., joki...|(13423,[0,24,297,...|(13423,[0,24,297,...|(13424,[0,24,297,...|
| spam|Free entry in 2 a...|   155|  1.0|[free, entry, in,...|[free, entry, 2, ...|(13423,[2,13,19,3...|(13423,[2,13,19,3...|(13424,[2,13,19,3...|
+-----+--------------------+------+-----+--------------------+--------------------+--------------------+--------------------+--------------------+
only showing top 3 rows

We only need two fields for training

clean_df = clean_df.select('label', 'features')
clean_df.show(3)



+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(13424,[7,11,31,6...|
|  0.0|(13424,[0,24,297,...|
|  1.0|(13424,[2,13,19,3...|
+-----+--------------------+
only showing top 3 rows

For example, the features of the first row

clean_df.take(1)[0]['features']

gives us

SparseVector(13424, {7: 3.1126, 11: 3.2055, 31: 3.822, 61: 4.2072, 72: 4.322, 344: 5.4072, 625: 5.918, 731: 6.1411, 1409: 6.6801, 1598: 6.8343, 4485: 7.5274, 6440: 7.9329, 8092: 7.9329, 8838: 7.9329, 11344: 7.9329, 12979: 7.9329, 13423: 111.0})

Naive Bayes is a simple and widely used algorithm for spam classification

from pyspark.ml.classification import NaiveBayes
nb = NaiveBayes()

Since we don’t have a separate validation set, we’ll split the given data

train, test = clean_df.randomSplit([0.7, 0.3])

Let’s see what training and testing give us

spam_detector = nb.fit(train)
predictions = spam_detector.transform(test)
predictions.show(3)

shows the first three predictions

+-----+--------------------+--------------------+--------------------+----------+
|label|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
|  0.0|(13424,[0,1,3,9,1...|[-571.90621659279...|[1.0,1.2511567713...|       0.0|
|  0.0|(13424,[0,1,7,8,1...|[-873.98191321598...|[1.0,1.6496395700...|       0.0|
|  0.0|(13424,[0,1,7,8,1...|[-1156.3332952900...|[1.0,3.1146776343...|       0.0|
+-----+--------------------+--------------------+--------------------+----------+
only showing top 3 rows

The accuracy can be obtained like so

from pyspark.ml.evaluation import MulticlassClassificationEvaluator

evaluator = MulticlassClassificationEvaluator()
print("Test Accuracy: " + str(evaluator.evaluate(predictions, {evaluator.metricName: "accuracy"})))


Test Accuracy: 0.9134673979280926

That’s not bad, considering the unrefined approach we have taken.

spark.stop()