Statistical Analysis and Visualization


Teaching: 40 min
Exercises: 0 min
  • How can I perform statistical analysis in Python?

  • How can I visualize my data?

  • Learn to perform statistical analysis with the scipy.stats module

  • Learn to visualize data with matplotlib

Statistical Analysis with Scipy

Like the Data Analysis tools in Excel, Python can perform many kinds of statistical analysis. The scipy module provides many useful statistical tools (more than Excel offers).


Assume you are interested in selling either energy drinks or cola, and you want to know whether the two categories have a significant difference in average bottle sales. We can perform a two-sample t-test that does not assume equal variances (Welch's t-test).
First, import the stats module.

from scipy import stats

Next, use an aggregate function to find the total bottles sold for each soda in each category. Remember the groupby function? Try it yourself first.

ct = inv_soda.groupby(['Item_Description','Category'], as_index=False).agg({"Bottles_Sold":"sum"})
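To see what this groupby/agg call produces, here is a minimal sketch on a small, made-up DataFrame. The column names match the lesson's inv_soda data, but the item names and numbers are invented purely for illustration.

```python
import pandas as pd

# Hypothetical toy data using the same column names as inv_soda (values are made up)
inv_soda = pd.DataFrame({
    "Item_Description": ["Bolt", "Bolt", "Zap", "Fizz", "Fizz", "Pop"],
    "Category": ["Energy Drink", "Energy Drink", "Energy Drink", "Cola", "Cola", "Cola"],
    "Bottles_Sold": [10, 5, 7, 20, 4, 9],
})

# One row per (item, category) pair, with Bottles_Sold summed across that pair's rows
ct = inv_soda.groupby(["Item_Description", "Category"], as_index=False).agg({"Bottles_Sold": "sum"})
print(ct)
```

Because as_index=False keeps the grouping keys as ordinary columns, a later filter such as ct[ct["Category"] == "Cola"] works directly; without it, Category would be part of the index rather than a column.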

Then, split the per-item totals into the two categories.

ct_eng = ct[ct["Category"] == "Energy Drink"]
ct_cola = ct[ct["Category"] == "Cola"]

Finally, you perform the t-test.

stats.ttest_ind(ct_eng['Bottles_Sold'],ct_cola['Bottles_Sold'], equal_var=False)

The output is:

Ttest_indResult(statistic=1.9035249229837687, pvalue=0.057480903679786859)

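The p-value (about 0.057) is above the conventional 0.05 significance level, so we cannot reject the null hypothesis that the two categories have the same average bottle sales.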
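To make the equal_var=False option less of a black box, the sketch below computes Welch's t statistic and its two-sided p-value by hand and compares them with stats.ttest_ind. The two arrays are made-up numbers, not the lesson's soda data, and 0.05 is only the conventional significance level.

```python
import numpy as np
from scipy import stats

# Hypothetical per-item sales totals for two groups (made-up numbers)
energy = np.array([120.0, 95.0, 150.0, 80.0, 110.0, 130.0])
cola = np.array([90.0, 70.0, 100.0, 60.0, 85.0])

# Welch's t statistic: difference of means divided by the unpooled standard error
m1, m2 = energy.mean(), cola.mean()
v1, v2 = energy.var(ddof=1), cola.var(ddof=1)  # sample variances
n1, n2 = len(energy), len(cola)
t_manual = (m1 - m2) / np.sqrt(v1 / n1 + v2 / n2)

# Welch-Satterthwaite approximation for the degrees of freedom
df = (v1 / n1 + v2 / n2) ** 2 / ((v1 / n1) ** 2 / (n1 - 1) + (v2 / n2) ** 2 / (n2 - 1))
# Two-sided p-value from the t distribution's survival function
p_manual = 2 * stats.t.sf(abs(t_manual), df)

result = stats.ttest_ind(energy, cola, equal_var=False)
print(result.statistic, t_manual)
print(result.pvalue, p_manual)

alpha = 0.05  # conventional significance level
if result.pvalue < alpha:
    print("Reject the null hypothesis of equal means")
else:
    print("Cannot reject the null hypothesis of equal means")
```

The hand-computed values should match scipy's, which shows that equal_var=False simply swaps the pooled-variance formula for the unpooled one with adjusted degrees of freedom.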


Visualization with matplotlib

Many of you are experienced at making charts in Excel. You can create similar visualizations in Python, and matplotlib is a good module to start with. First, import it.

# Jupyter only: display figures inside the notebook
%matplotlib inline
# In most cases pyplot is all you need, so import just that submodule
import matplotlib.pyplot as plt

As in Excel, you just need to supply the x-axis and y-axis values. Let's start with a very basic example:

# create three points for the first line
x = [1, 2, 3]
y = [2, 3, 5]
# create three points for the second line (optional; this shows that one graph can hold several lines)
x2 = [1.5, 2.5, 3.5]
y2 = [2.7, 3.7, 5.7]
# change size (set size before plotting)
plt.figure(figsize=(8, 5))
# plot the graph
plt.plot(x, y, label="first line")
plt.plot(x2, y2, label="second line")
# Set x, y limits
plt.xlim(0, 5)
plt.ylim(0, 7)
# Add x, y labels
plt.xlabel("independent variable")
plt.ylabel("dependent variable")
# Add title
plt.title("Two lines")
# Add legend (built from each line's label)
plt.legend()
# show the graph
plt.show()

The result is a single graph showing both lines on the same axes.

Without a graphical interface like Excel's, this is a little less intuitive. Review the code carefully and make sure you understand what each line does. Note that pyplot.plot has many parameters (see its documentation): you can customize almost everything, such as color, marker, and transparency (alpha).
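As a sketch of that customization, the example below passes a few common line properties to plot(). The specific color, style, and output filename are arbitrary choices for illustration. It selects the non-interactive "Agg" backend so it also runs as a standalone script; in a notebook using %matplotlib inline you can drop those two lines.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the script runs without a display
import matplotlib.pyplot as plt

x = [1, 2, 3]
y = [2, 3, 5]

fig, ax = plt.subplots(figsize=(6, 4))
# color, linestyle, marker, markersize, and alpha are standard line properties accepted by plot()
(line,) = ax.plot(x, y, color="tab:red", linestyle="--", marker="o",
                  markersize=8, alpha=0.6, label="first line")
ax.legend()
fig.savefig("custom_line.png")  # hypothetical output filename
```

plot() returns a list of the lines it drew, so unpacking it into line lets you inspect or change those properties later (for example, line.set_alpha(1.0)).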


Let's go back to our data.
To see the price distribution of sodas that cost less than $5, draw a histogram with pyplot.hist.

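# set size
plt.figure(figsize=(8, 5))
# feed in the data, set range, plot the graph (values outside 0-5 are left out)
plt.hist(soda['Bottle_Cost'], range=(0, 5))
# set x, y labels
plt.xlabel("bottle cost ($)")
plt.ylabel("number of soda")
# Add title
plt.title("Price distribution")
plt.show()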


This is a very basic example. You can customize many more parameters, such as colors and the number of bins.
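Here is a minimal sketch of those extra parameters on made-up bottle costs (not the lesson's soda data). It sets bins, color, and edgecolor, and keeps the counts that hist returns so you can check what went into each bar. As before, the "Agg" backend line is only needed outside a notebook.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for running as a script
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical bottle costs (made-up numbers), including two above $5
costs = np.array([0.8, 1.2, 1.5, 2.05, 2.2, 2.4, 3.1, 3.5, 4.9, 6.0, 7.5])

fig, ax = plt.subplots(figsize=(6, 4))
# bins=5 splits range=(0, 5) into five equal-width bins; values outside the range are not counted
counts, edges, patches = ax.hist(costs, bins=5, range=(0, 5),
                                 color="skyblue", edgecolor="black")
ax.set_xlabel("bottle cost ($)")
ax.set_ylabel("number of soda")
ax.set_title("Price distribution")
print(counts, edges)
```

The two costs above $5 do not appear in any bar. Filtering first, for example soda[soda['Bottle_Cost'] < 5], is a more explicit way to express the same restriction.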

Scatter plot

A scatter plot is a two-dimensional data visualization that uses dots to represent the values obtained for two different variables.

# set figure size
fig, ax = plt.subplots(figsize=(10, 7))
# Plot data
ax.scatter(soda['Bottle_Cost'], soda['Bottle_Retail_Price'])
# Set title and axis labels
ax.set_title("Bottle_Retail_Price VS Bottle_Cost")
ax.set_xlabel("Bottle_Cost")
ax.set_ylabel("Bottle_Retail_Price")
# Set x, y limits
ax.set_xlim(1, 5)
ax.set_ylim(1, 10)
plt.show()
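If you want to experiment with a scatter plot without loading the soda data, the self-contained sketch below generates made-up cost/retail pairs (retail assumed to be roughly 1.5 times cost plus random noise). It uses s and alpha so that overlapping points remain visible.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for running as a script
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical cost/retail pairs (made-up): retail is about 1.5x cost plus noise
rng = np.random.default_rng(0)
cost = rng.uniform(1, 5, size=50)
retail = 1.5 * cost + rng.normal(0, 0.3, size=50)

fig, ax = plt.subplots(figsize=(10, 7))
# s sets the marker size; alpha makes overlapping points easier to see
points = ax.scatter(cost, retail, s=30, alpha=0.5, color="tab:blue")
ax.set_title("Bottle_Retail_Price VS Bottle_Cost")
ax.set_xlabel("Bottle_Cost")
ax.set_ylabel("Bottle_Retail_Price")
ax.set_xlim(1, 5)
ax.set_ylim(1, 10)
```

Calling methods on ax (set_title, set_xlim, and so on) targets this specific plot, which matters once a figure contains more than one subplot.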




Assume you have just opened a new convenience store and want to see which categories of soda other stores spent the most money on in their inventory.
What are the top three categories in terms of total cost? Visualize the result with a pie chart. Read the documentation and try to figure out how to draw it yourself. Hint:

  1. Calculate the total cost for each soda (arithmetic)
  2. Calculate the total cost for each category (aggregation)
  3. Sort the aggregated values
  4. Draw the graph (what is x? what are the labels?)
  5. Set the autopct parameter to '%.0f%%'


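# total cost of each soda = cost per bottle * bottles sold
inv_soda["Total_Cost"] = inv_soda["Bottle_Cost"] * inv_soda["Bottles_Sold"]
# total cost per category
ct = inv_soda.groupby(['Category'])["Total_Cost"].agg(["sum"]).reset_index()
# sort from largest to smallest total cost
ct = ct.sort_values("sum", ascending=False)
print(ct.head(3))  # the top three categories
plt.pie(ct["sum"], labels=ct["Category"], autopct='%.0f%%')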
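The exercise asks for the top three categories, so one option is to plot only those. The sketch below does that with Series.nlargest on a small, made-up inventory table that reuses the lesson's column names. Keep in mind that the percentages are then shares of the three plotted categories only, not of all categories.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for running as a script
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical inventory rows (made-up numbers) with the lesson's column names
inv = pd.DataFrame({
    "Category": ["Cola", "Cola", "Energy Drink", "Lemon-Lime", "Root Beer", "Energy Drink"],
    "Bottle_Cost": [1.0, 1.2, 2.5, 0.9, 1.1, 2.0],
    "Bottles_Sold": [100, 50, 40, 30, 20, 10],
})
inv["Total_Cost"] = inv["Bottle_Cost"] * inv["Bottles_Sold"]

# Sum the cost per category, then keep the three largest totals (already sorted descending)
totals = inv.groupby("Category")["Total_Cost"].sum()
top3 = totals.nlargest(3)

fig, ax = plt.subplots()
# with autopct set, pie returns the wedges, the label texts, and the percentage texts
wedges, texts, autotexts = ax.pie(top3, labels=list(top3.index), autopct="%.0f%%")
ax.set_title("Top three categories by total cost")
```

If you prefer to keep every category in the chart, an alternative is to sum the remaining categories into a single "Other" slice so the percentages still refer to the whole inventory.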

Key Points

  • Use scipy.stats.ttest_ind for an independent two-sample t-test (pass equal_var=False for Welch's version)

  • scipy has many other statistical tools; see its documentation for details

  • You can customize many parameters of a graph; see the matplotlib documentation for details

  • For both scipy and matplotlib, the hardest part is usually preprocessing your data. After that, you find the right function and feed your data into it.