State Transformation in Spark

Aditya Pimparkar
Aug 23, 2020

Spark Streaming can handle stateful operations, i.e. operations whose state can be modified by subsequent batches of data. Stateful transformation is a feature of Spark Streaming that lets us maintain state between micro-batches. In other words, it keeps state across a period of time that can be as long as an entire streaming session, which makes it possible to sessionize our data. Spark persists this state in a state store backed by the streaming query's checkpoint location, so the state survives failures and restarts.

Stateful Transformation In Spark

A simple use case for stateful transformation is tracking user activity. We may want to emit an alert based on user inactivity (for example, login expiry on social media platforms), or maintain user sessions over a fixed or unbounded period of time and persist those sessions for later analysis. Both scenarios require custom processing logic.
Structured Streaming offers two APIs for these cases: mapGroupsWithState and flatMapGroupsWithState. mapGroupsWithState operates on groups and outputs exactly one result row per group, whereas flatMapGroupsWithState can emit zero, one or multiple result rows per group.
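The difference in output cardinality can be sketched in plain Java, without Spark. In Spark's Java API, MapGroupsWithStateFunction.call returns a single result, while FlatMapGroupsWithStateFunction.call returns an Iterator of results; the hypothetical GroupOutputDemo class below only imitates that contrast on an in-memory list of (key, value) events and does not use any Spark classes.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical plain-Java illustration (no Spark) of output cardinality per group.
class GroupOutputDemo {
    // Groups (key, value) events by key, keeping keys in first-seen order.
    static Map<String, List<Integer>> group(List<Map.Entry<String, Integer>> events) {
        Map<String, List<Integer>> groups = new LinkedHashMap<>();
        for (Map.Entry<String, Integer> e : events) {
            groups.computeIfAbsent(e.getKey(), k -> new ArrayList<>()).add(e.getValue());
        }
        return groups;
    }

    // mapGroupsWithState-like: exactly one output row per group (here, the sum).
    static List<String> oneRowPerGroup(Map<String, List<Integer>> groups) {
        List<String> out = new ArrayList<>();
        groups.forEach((key, values) ->
                out.add(key + "=" + values.stream().mapToInt(Integer::intValue).sum()));
        return out;
    }

    // flatMapGroupsWithState-like: zero or more rows per group
    // (here, one row per value above the threshold).
    static List<String> manyRowsPerGroup(Map<String, List<Integer>> groups, int threshold) {
        List<String> out = new ArrayList<>();
        groups.forEach((key, values) -> {
            for (int v : values) {
                if (v > threshold) {
                    out.add(key + ":" + v);
                }
            }
        });
        return out;
    }
}
```

With events (a, 1), (b, 7), (a, 9), the one-row version always yields one row per key, while the many-rows version can yield several rows for a key or none at all. In Spark's Java API, flatMapGroupsWithState also takes an OutputMode argument in addition to the encoders and the timeout configuration.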

In this article, we implement a finite state machine using the mapGroupsWithState API.

State Diagram

We will implement the state diagram above. The set of states and the transition function (rules) are explained below:
1. The example uses four states.
2. State “first” is the starting state.
3. State “fourth” is the final state.
4. Rules for inputs (transition function):
4.1. ‘a’: state_measure = state_measure + input_measure
4.2. ‘b’: state_measure = state_measure - input_measure
4.3. ‘c’: state_measure = state_measure
5. Only the keys reaching the final state are shown in the output.
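Rules 4.1 to 4.3 describe how the measure changes; which state each input leads to is encoded in the if-else branches of the updateEvents function later in this article. The sketch below writes both parts out as a single pure function in plain Java, with no Spark dependency. The StateMachine class and its step method are hypothetical names, and the transition table is read off the updateEvents code, not taken from the diagram.

```java
// Hypothetical plain-Java helper mirroring the if-else rules in updateEvents.
class StateMachine {
    // Immutable (state name, measure) pair.
    static final class State {
        final String name;
        final int measure;

        State(String name, int measure) {
            this.name = name;
            this.measure = measure;
        }
    }

    static final State START = new State("first", 0);

    // Applies one input symbol and its input_measure to the current state.
    static State step(State s, String symbol, int inputMeasure) {
        if (s.name.equals("fourth")) {
            return s; // final state: ignore further input
        }
        // Rules 4.1-4.3: 'a' adds, 'b' subtracts, 'c' keeps the measure.
        int measure = s.measure;
        if (symbol.equals("a")) {
            measure += inputMeasure;
        } else if (symbol.equals("b")) {
            measure -= inputMeasure;
        }
        // State transitions as encoded in updateEvents.
        String next;
        switch (s.name + "/" + symbol) {
            case "first/a":
            case "second/a":
                next = "second";
                break;
            case "first/b":
            case "first/c":
            case "second/b":
            case "third/a":
            case "third/b":
                next = "third";
                break;
            case "second/c":
            case "third/c":
                next = "fourth";
                break;
            default:
                next = s.name; // unrecognized symbol: no transition
        }
        return new State(next, measure);
    }
}
```

For example, starting from START, the inputs (a, 5), (b, 2), (c, 100) move a key through “second” (measure 5) and “third” (measure 3) to “fourth” (measure 3), after which further inputs leave it unchanged.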

The input in this example is a CSV file. A sample of the input is shown below:

Input data

In the input data, ‘id’ is the key, and state is maintained separately for each key. ‘alphabet’ is the input symbol and ‘measure’ is the input_measure. Both columns are used by the transition function to change the state of a particular key.

Timeouts in stateful transformation:

Earlier in this article, we discussed the use case of emitting an alert based on user inactivity (for example, login expiry on social media platforms). In Structured Streaming, this can be achieved with timeouts. A timeout dictates how long to wait before an intermediate state is timed out. Timeouts can be based either on processing time (GroupStateTimeout.ProcessingTimeTimeout()) or on event time (GroupStateTimeout.EventTimeTimeout()); event-time timeouts also require a watermark on the input. If no timeout is required, GroupStateTimeout.NoTimeout() should be used. The timeout itself is set per group inside the function, with state.setTimeoutDuration(...) for processing-time timeouts or state.setTimeoutTimestamp(...) for event-time timeouts. When a group times out, the function is called with an empty iterator of values, so it should check state.hasTimedOut() before processing the values.
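The inactivity idea behind a processing-time timeout can be modelled in plain Java. This is only an analogy under stated assumptions, not how Spark implements timeouts: the hypothetical InactivityTracker class keeps a last-seen time per user, new activity pushes the deadline forward, and expire() reports and drops users that have been idle for at least the timeout, roughly what a function would do by calling state.remove() once state.hasTimedOut() is true.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

// Hypothetical plain-Java model of a processing-time inactivity timeout.
// It only imitates the idea; in Spark the per-group calls would be
// state.setTimeoutDuration(...), state.hasTimedOut() and state.remove().
class InactivityTracker {
    private final long timeoutMs;
    private final Map<String, Long> lastSeen = new HashMap<>();

    InactivityTracker(long timeoutMs) {
        this.timeoutMs = timeoutMs;
    }

    // Records activity for a user at the given processing time (ms);
    // new activity pushes the user's deadline forward.
    void touch(String user, long nowMs) {
        lastSeen.put(user, nowMs);
    }

    // Returns (sorted) users idle for at least timeoutMs and drops their
    // state, like removing the state of a timed-out group.
    List<String> expire(long nowMs) {
        List<String> expired = new ArrayList<>();
        Iterator<Map.Entry<String, Long>> it = lastSeen.entrySet().iterator();
        while (it.hasNext()) {
            Map.Entry<String, Long> e = it.next();
            if (nowMs - e.getValue() >= timeoutMs) {
                expired.add(e.getKey());
                it.remove();
            }
        }
        Collections.sort(expired);
        return expired;
    }
}
```

With a 1000 ms timeout, a user last seen at time 0 expires at time 1200, while a user who is active again at 1300 stays alive until 2300.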

The implementation of the finite state machine is explained below:

1. Creation of the source DataFrame:
String schema = "id STRING, alphabet STRING, measure INT";
// Read the CSV files in state/ as a stream, using the explicit schema above
Dataset<Row> input = spark.readStream()
        .schema(schema)
        .option("header", true)
        .csv("state/");

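This creates a streaming input DataFrame from the CSV files in the state/ directory, with the columns ‘id’, ‘alphabet’ and ‘measure’ explained above. The schema is passed explicitly because streaming file sources require it by default.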

2. Use of the groupByKey and mapGroupsWithState APIs:

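// Group rows by 'id' (index 0) and run the state machine on each group
Dataset<StateUpdate> output = input.groupByKey(new MapFunction<Row, String>() {
    @Override
    public String call(Row value) throws Exception {
        return value.getString(0);
    }
}, Encoders.STRING()).mapGroupsWithState(
        updateEvents,                      // transition function
        Encoders.bean(StateInfo.class),    // state encoder
        Encoders.bean(StateUpdate.class),  // output encoder
        GroupStateTimeout.NoTimeout());    // no timeout in this example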

The groupByKey function groups our data by a key column. In this case, the key column is ‘id’, which is at index 0 of each row (value.getString(0)).

groupByKey returns a KeyValueGroupedDataset, which is the input to mapGroupsWithState; mapGroupsWithState then produces exactly one result row for each group that has data in a micro-batch. Its parameters are the function that implements the transition function (rules), an encoder for the state class (StateInfo), an encoder for the output class (StateUpdate) and the timeout configuration (GroupStateTimeout). We use GroupStateTimeout.NoTimeout() in this example.

Our function in this case is updateEvents. Before explaining the updateEvents function, let us first look at the StateInfo and StateUpdate classes.

The StateInfo class stores the state of a key in the state diagram: the name of its current state and its accumulated measure. The class definition is given below:

public static class StateInfo implements Serializable {
    private int measure = 0;
    private String statename;

    public int getMeasure() {
        return measure;
    }

    public void setMeasure(int measure) {
        this.measure = measure;
    }

    public String getStatename() {
        return statename;
    }

    public void setStatename(String statename) {
        this.statename = statename;
    }
}

The StateUpdate class holds the current state of a key together with the key itself. The updateEvents function returns one StateUpdate for each key that has data in a micro-batch of the input stream.

The class definition is given below:

public static class StateUpdate implements Serializable {
    private String id;
    private String statename;
    private int measure;

    public StateUpdate() {
    }

    public StateUpdate(String id, String statename, int measure) {
        this.id = id;
        this.statename = statename;
        this.measure = measure;
    }

    public String getId() {
        return id;
    }

    public void setId(String id) {
        this.id = id;
    }

    public String getStatename() {
        return statename;
    }

    public void setStatename(String statename) {
        this.statename = statename;
    }

    public int getMeasure() {
        return measure;
    }

    public void setMeasure(int measure) {
        this.measure = measure;
    }
}

The updateEvents function is a MapGroupsWithStateFunction whose type parameters are the key type (String), the input row type (Row), the state type (StateInfo) and the output type (StateUpdate). Its call method is overridden to implement our transition function with simple if-else conditions. If no state exists yet for a key, an initial state is created for that key before the rules are applied.

GroupState<StateInfo> is a handle to the state of the key currently being processed, and Spark keeps this state between micro-batches. If a state for the key already exists, it is read with state.get() and used for further transitions instead of creating a new initial state. After the transitions for the key have been applied, the new state is saved back with state.update().
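To make the exists/get/update pattern concrete, here is a hypothetical in-memory stand-in written in plain Java. ToyGroupState is not Spark's GroupState; it only borrows the method names exists, get, update and remove, and like Spark's documented get() it throws NoSuchElementException when no state exists. The addToCounter helper (also hypothetical) shows the same get-or-initialise step that updateEvents performs.

```java
import java.util.NoSuchElementException;

// Hypothetical stand-in for Spark's GroupState<S>: same method names,
// but just a plain in-memory holder for one key's state.
class ToyGroupState<S> {
    private S value; // null means "no state for this key yet"

    boolean exists() {
        return value != null;
    }

    S get() {
        if (value == null) {
            throw new NoSuchElementException("state does not exist");
        }
        return value;
    }

    void update(S newState) {
        if (newState == null) {
            throw new IllegalArgumentException("state must not be null");
        }
        value = newState;
    }

    void remove() {
        value = null;
    }

    // Get-or-initialise pattern, as in updateEvents: create an initial
    // state if none exists, then read it, change it and save it back.
    static int addToCounter(ToyGroupState<Integer> state, int delta) {
        if (!state.exists()) {
            state.update(0);
        }
        state.update(state.get() + delta);
        return state.get();
    }
}
```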

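The code for the updateEvents function is given below. After all the rows of a key in the current micro-batch have been processed, the function saves and returns the state reached by that key. Once a key reaches the final state “fourth”, any further input for it is ignored.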

// Requires: java.util.Iterator, org.apache.spark.api.java.function.MapGroupsWithStateFunction,
// org.apache.spark.sql.Row and org.apache.spark.sql.streaming.GroupState
MapGroupsWithStateFunction<String, Row, StateInfo, StateUpdate> updateEvents =
        new MapGroupsWithStateFunction<String, Row, StateInfo, StateUpdate>() {
    @Override
    public StateUpdate call(String key, Iterator<Row> values, GroupState<StateInfo> state) throws Exception {
        int measure = 0;
        StateInfo updatedState = new StateInfo();
        String name = "first";
        /*
         * If a state doesn't exist yet (initial state), create one with initial values:
         * state name = first
         * state measure = 0
         */
        if (!state.exists()) {
            updatedState.setStatename(name);
            updatedState.setMeasure(measure);
            state.update(updatedState);
        }
        StateInfo oldstate = state.get();
        name = oldstate.getStatename();
        measure = oldstate.getMeasure();
        while (values.hasNext()) {
            Row temp = values.next();
            String alpha = temp.getString(1);   // input symbol
            int temp_measure = temp.getInt(2);  // input_measure
            System.out.println(alpha + " " + name + " " + temp_measure);
            if (name.equals("first") && alpha.equals("a")) {
                name = "second";
                measure = measure + temp_measure;
            } else if (name.equals("first") && alpha.equals("b")) {
                name = "third";
                measure = measure - temp_measure;
            } else if (name.equals("first") && alpha.equals("c")) {
                name = "third";
            } else if (name.equals("second") && alpha.equals("a")) {
                name = "second";
                measure = measure + temp_measure;
            } else if (name.equals("second") && alpha.equals("b")) {
                name = "third";
                measure = measure - temp_measure;
            } else if (name.equals("second") && alpha.equals("c")) {
                name = "fourth";
            } else if (name.equals("third") && alpha.equals("a")) {
                name = "third";
                measure = measure + temp_measure;
            } else if (name.equals("third") && alpha.equals("b")) {
                name = "third";
                measure = measure - temp_measure;
            } else if (name.equals("third") && alpha.equals("c")) {
                name = "fourth";
            } else if (name.equals("fourth")) {
                // Final state reached: ignore the remaining input for this key
                break;
            }
        }
        // Save the new state for the next micro-batch and emit it
        updatedState.setStatename(name);
        updatedState.setMeasure(measure);
        state.update(updatedState);
        return new StateUpdate(key, name, measure);
    }
};
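The behaviour of updateEvents across micro-batches can be simulated in plain Java without a Spark cluster. In the hypothetical UpdateEventsSimulation class below, a HashMap stands in for Spark's per-key state store, each runBatch call plays the role of one micro-batch, and rows are string arrays {id, alphabet, measure}. The transition table is read off the if-else branches above; none of this is Spark API.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical plain-Java simulation of updateEvents over successive micro-batches.
// A HashMap stands in for Spark's per-key state store; no Spark APIs are used.
class UpdateEventsSimulation {
    // Transitions read off the if-else branches of updateEvents.
    private static final Map<String, String> NEXT = Map.of(
            "first/a", "second", "first/b", "third", "first/c", "third",
            "second/a", "second", "second/b", "third", "second/c", "fourth",
            "third/a", "third", "third/b", "third", "third/c", "fourth");

    // Per-key state, similar to StateInfo.
    private static final class Info {
        final String statename;
        final int measure;

        Info(String statename, int measure) {
            this.statename = statename;
            this.measure = measure;
        }
    }

    private final Map<String, Info> store = new HashMap<>();

    // Processes one micro-batch of rows {id, alphabet, measure} and returns one
    // "id:statename:measure" entry per key in the batch (like a StateUpdate row).
    List<String> runBatch(List<String[]> rows) {
        // Group the batch by id, keeping keys in first-seen order (like groupByKey).
        Map<String, List<String[]>> groups = new LinkedHashMap<>();
        for (String[] row : rows) {
            groups.computeIfAbsent(row[0], k -> new ArrayList<>()).add(row);
        }
        List<String> out = new ArrayList<>();
        for (Map.Entry<String, List<String[]>> group : groups.entrySet()) {
            // Existing state, or the initial state ("first", 0) for a new key.
            Info state = store.getOrDefault(group.getKey(), new Info("first", 0));
            String name = state.statename;
            int measure = state.measure;
            for (String[] row : group.getValue()) {
                if (name.equals("fourth")) {
                    break; // final state: ignore further input
                }
                String alpha = row[1];
                int inputMeasure = Integer.parseInt(row[2]);
                if (alpha.equals("a")) {
                    measure += inputMeasure;
                } else if (alpha.equals("b")) {
                    measure -= inputMeasure;
                }
                name = NEXT.getOrDefault(name + "/" + alpha, name);
            }
            store.put(group.getKey(), new Info(name, measure)); // like state.update(...)
            out.add(group.getKey() + ":" + name + ":" + measure);
        }
        return out;
    }
}
```

A key that only reaches “second” in one batch resumes from “second” in the next, because its state is read back from the map; only keys present in a batch produce output for that batch, which matches how mapGroupsWithState is invoked only for groups with new data.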

Finally, from the Dataset<StateUpdate> output, we keep only the records whose statename is ‘fourth’, since that is the final state of our state diagram, and print the result of each micro-batch to the console.

// Keep only keys that reached the final state and print each micro-batch
StreamingQuery query = output.filter("statename = 'fourth'")
        .writeStream()
        .outputMode("update")
        .format("console")
        .start();
query.awaitTermination();

I hope this article has helped you understand the basics of stateful transformations in Spark. These basics can be extended to implement more complex logic.

The code can be found at the link below.

https://github.com/parkar-12/Stateful-Trasformation-Spark/tree/master/statetest
