Matplotlib in Machine Learning – Shiksha Online

Updated on Nov 22, 2022 19:27 IST

In applied Statistics and Machine Learning, data visualization is one of the most important skills for building a qualitative understanding of the data at hand. It helps you explore the data and extract relevant information by identifying patterns, relationships, outliers, and much more. This article explores the use of Matplotlib in Machine Learning.

Visualizations are one of the easiest ways to take in and analyze information. Data visualization also underpins high-level analysis in Exploratory Data Analysis (EDA). Python offers several data visualization libraries, of which Matplotlib is the most popular and widely used. This blog covers Matplotlib in Machine Learning in the sections that follow.

Introduction to Matplotlib

Matplotlib is an open-source plotting library that is used to create static 2D plots, although it does have some support for 3D visualizations as well. 

It is a comprehensive library that makes producing both simple and advanced plots straightforward and intuitive. 

It can be used in Python scripts, Jupyter notebooks, and web application servers.

Installing Matplotlib

Let’s start with installing the library in your working environment first:

 
#Windows, Linux, macOS users:
python -m pip install -U matplotlib
#To install Matplotlib in Jupyter Notebook:
pip install matplotlib
#To install Matplotlib in Anaconda Prompt:
conda install matplotlib

Importing Matplotlib

Now let’s import the Matplotlib library along with the other libraries we might need today:

 
import pandas as pd
import numpy as np
#Either of the following two import styles works:
from matplotlib import pyplot as plt
import matplotlib.pyplot as plt
#Jupyter Notebook only:
%matplotlib inline

In Matplotlib, pyplot is used to create figures and change their characteristics.

The %matplotlib inline magic command displays plots directly below the cell when using Jupyter Notebook. It is not valid Python outside IPython/Jupyter, so omit it in plain scripts.

Creating a Simple Plot using Matplotlib

A line plot is the most basic plot to create. Simply use plt.plot() as shown below:

 
#Example
plt.plot([32,26,43,41,16,37,29])
plt.ylabel('marks out of 50')
plt.show()

If we provide a single list or array to the plot() function, Matplotlib assumes it is a sequence of y values and automatically generates the x values for you, starting from 0. To plot y against x, pass both sequences:

 
#Example
plt.plot([1,2,3,4,5,6,7], [32,26,43,41,16,37,29], 'g^')
plt.xlabel('roll no.')
plt.ylabel('marks out of 50')
plt.show()

Do you see how we changed the type of plot above? For every x, y pair of arguments, there is an optional third argument: a format string that indicates the color (g for green) and the marker or line style (^ for triangle markers) of the plot. Because 'g^' specifies a marker but no line style, only the markers are drawn.

The default format string is 'b-', which is a solid blue line.
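
As a quick check of both points above, here is a minimal sketch that plots the same marks twice: once with y values only, and once with explicit x values and a different format string. The variable names are illustrative and not part of the original example.

```python
import matplotlib.pyplot as plt

marks = [32, 26, 43, 41, 16, 37, 29]

fig, ax = plt.subplots()

# y values only: Matplotlib fills in x = 0, 1, 2, ...
(auto_line,) = ax.plot(marks)

# x and y with a format string: 'r' = red, 'o' = circle markers, '--' = dashed line
(styled_line,) = ax.plot([1, 2, 3, 4, 5, 6, 7], marks, 'ro--')

# Inspect the x values Matplotlib generated for the first line
auto_x = [int(v) for v in auto_line.get_xdata()]
print(auto_x)
# In a script, call plt.show() here to display the figure
```

The first line starts at x = 0 while the second starts at x = 1, so the two curves appear shifted by one position even though the marks are identical.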

Working with Figures and Axes

Figure Object

The Figure object should be considered as your frame. It is the bounded space within which one or more graphs can be plotted.

plt.figure() is used to create an empty Figure object in Matplotlib. It accepts, among others, the following optional parameters:

  • figsize:  Figure dimension (width, height) in inches
  • dpi: Dots per inch
  • facecolor: Figure patch facecolor
  • edgecolor: Figure patch edge color
  • linewidth: Linewidth of the frame
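
Since figsize is given in inches and dpi in dots per inch, the pixel size of the rendered figure is their product. A small sketch of that arithmetic, with parameter values chosen only for illustration:

```python
import matplotlib.pyplot as plt

# 7 x 5 inches at 100 dots per inch
fig = plt.figure(figsize=(7, 5), dpi=100, facecolor='pink',
                 edgecolor='b', linewidth=2)

# Pixel size = size in inches * dots per inch
width_in, height_in = fig.get_size_inches()
width_px = round(width_in * fig.dpi)
height_px = round(height_in * fig.dpi)
print(width_px, height_px)
```

Note that the figure edge is only drawn when linewidth is greater than zero, so edgecolor has no visible effect unless linewidth is set as well.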

Axes Object

A figure can contain one or more Axes (plots). The Axes object is the canvas on which you plot your graphs. Each Axes has a title, an x-label, and a y-label.

  • add_axes() to add axes to the figure 
  • ax.set_title() for setting title
  • ax.set_xlabel() and ax.set_ylabel() for setting x and y-label respectively 
 
#creating fig
fig=plt.figure(figsize=[7, 5], facecolor='pink', edgecolor='b')
#adding axes to fig
ax = fig.add_axes([0,0,1,1])
ax.set_title("New Figure and Axes")
ax.set_xlabel('x-axis')
ax.set_ylabel('y-axis')
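
The list passed to add_axes() is [left, bottom, width, height], each given as a fraction of the figure's width or height. A hedged sketch of how this can be used to place a small inset plot inside a larger one (the positions and the names main_ax and inset_ax are made up for illustration):

```python
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(7, 5))

# Main axes: starts 10% in from the left and bottom, spans 80% of the figure
main_ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
main_ax.plot([1, 2, 3, 4, 5, 6, 7], [32, 26, 43, 41, 16, 37, 29])
main_ax.set_title("Main Axes")

# Inset axes: a smaller rectangle in the upper-right area of the figure
inset_ax = fig.add_axes([0.6, 0.6, 0.25, 0.25])
inset_ax.plot([1, 2, 3], [32, 26, 43], 'g^')
inset_ax.set_title("Inset")

print(len(fig.axes))
# In a script, call plt.show() here to display the figure
```

Leaving a margin (0.1 instead of 0) keeps the tick labels and axis labels inside the figure area.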

Let’s add the line plot we created in the above example to our ‘New Figure and Axes’:

 
#Example
fig=plt.figure(figsize=[7, 5], facecolor='pink', edgecolor='b')
ax = fig.add_axes([0,0,1,1])
ax.set_title("Example Line Plot")
ax.set_xlabel('roll no.')
ax.set_ylabel('marks out of 50')
ax.plot([1,2,3,4,5,6,7],[32,26,43,41,16,37,29])
plt.show()

As discussed, let’s see how we can add multiple plots in a single figure:

Subplots 

We use pyplot.subplots() to create a figure and a grid of subplots with a single call. The subplots() function returns a Figure object and either a single Axes object or an array of Axes objects, depending on the size of the grid.

Let’s add another plot to the above example. 

 
#Example
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=[12,5], facecolor='pink', edgecolor='b')
fig.suptitle('Example Subplots')
ax1.plot([1,2,3,4,5,6,7],[32,26,43,41,16,37,29])
ax1.set_title("Class 1")
ax1.set_xlabel('roll no.')
ax1.set_ylabel('marks out of 50')
ax2.plot([1,2,3,4,5,6,7],[48,32,40,44,36,27,21], color='g')
ax2.set_title("Class 2")
ax2.set_xlabel('roll no.')
ax2.set_ylabel('marks out of 50')
plt.show()
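
For grids with more than one row and column, subplots() returns the Axes as a 2-D NumPy array. The following sketch, which reuses the marks from earlier with arbitrary format strings, shows one way to loop over such a grid:

```python
import matplotlib.pyplot as plt

# 2 rows x 2 columns: 'axes' is a 2-D NumPy array of Axes objects
fig, axes = plt.subplots(2, 2, figsize=(10, 8))

roll_numbers = [1, 2, 3, 4, 5, 6, 7]
marks = [32, 26, 43, 41, 16, 37, 29]
styles = ['b-', 'g^', 'ro--', 'k:']

# axes.flat iterates over the grid row by row
for number, (ax, style) in enumerate(zip(axes.flat, styles), start=1):
    ax.plot(roll_numbers, marks, style)
    ax.set_title(f"Panel {number}: '{style}'")

# Adjust spacing so that titles and tick labels do not overlap
fig.tight_layout()
print(axes.shape)
# In a script, call plt.show() here to display the figure
```

An individual panel can also be addressed by position, for example axes[0, 1] for the top-right plot.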

You can learn more about Matplotlib Subplots here.

Important Matplotlib Plots in Machine Learning 

Matplotlib provides a wide variety of plot formats to support various visualization methods. The most popular ones are linked here:

Use of Matplotlib in Machine Learning 

As discussed above, the Matplotlib library is used during the Exploratory Data Analysis (EDA) and Data Visualization phases of an ML model building process. 

Let’s understand how EDA is done using Matplotlib with the example of the Harmonized System (HS). It was developed by the World Customs Organization (WCO) as a multipurpose international product nomenclature that describes the types of commodities imported or exported each year. The system is used by more than 200 countries and comprises about 5,000 commodity groups, each identified by a six-digit code (the HS code). We will make use of two datasets that contain records of imported and exported products. You can find them here.

Step 1 – Import the required libraries

 
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import warnings
warnings.filterwarnings("ignore")

Step 2 – Load the datasets

 
#Read the data from the uploaded csv files
data_export = pd.read_csv('export.csv')
data_import = pd.read_csv('import.csv')

You can concatenate the two datasets as shown below:

 
#Concatenate the data
data_export['cat'] = 'E'
data_import['cat'] = 'I'
df = pd.concat([data_export,data_import],ignore_index=True)
df
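
If you do not have the CSV files at hand, the concatenation step can be tried on a toy example. The two small DataFrames below are made-up stand-ins for export.csv and import.csv; only the column names follow the ones used in this article.

```python
import pandas as pd

# Tiny made-up stand-ins for export.csv and import.csv
data_export = pd.DataFrame({
    'country': ['U K', 'CANADA'],
    'year': [2015, 2015],
    'value': [12.5, 3.0],
})
data_import = pd.DataFrame({
    'country': ['U K', 'CANADA', 'CANADA'],
    'year': [2015, 2015, 2016],
    'value': [7.0, 9.5, 4.0],
})

# Tag each row with its origin before stacking the two tables
data_export['cat'] = 'E'
data_import['cat'] = 'I'

# ignore_index=True renumbers the rows 0..n-1 instead of repeating 0, 1, ...
df = pd.concat([data_export, data_import], ignore_index=True)
print(df)
```

The cat column is what later lets us filter the combined table back into imports and exports.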

Now, let’s proceed with analyzing this data.

Step 3 – Perform EDA

Which are the top ten countries where the value of export is highest? 

You need to find out the top ten export destinations of India:

 
df1 = data_export.groupby('country').agg({'value':'sum'})
df1 = df1.sort_values(by='value', ascending = False)
df1 = df1[:10]
df1

Let’s plot a bar graph for the same. The bar() function plots a vertical bar graph, and barh() plots a horizontal one:

 
#Plotting a horizontal bar plot
fig = plt.figure(figsize = (10, 5))
plt.barh(df1.index, df1.value)
plt.xlabel("Value")
plt.ylabel("Country")
plt.title("Country-wise Export")
plt.show()
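
One detail worth knowing is that barh() draws the first entry at the bottom of the chart, so a table sorted in descending order appears with the largest bar at the bottom. The sketch below uses made-up data ('COUNTRY A' and 'COUNTRY B' are placeholders) and shows one possible way to put the largest bar on top. It also uses nlargest() as a shorter alternative to sorting and slicing.

```python
import pandas as pd
import matplotlib.pyplot as plt

# Made-up export records; only the column names match the article's dataset
data_export = pd.DataFrame({
    'country': ['COUNTRY A', 'U K', 'COUNTRY A', 'COUNTRY B', 'U K', 'CANADA'],
    'value': [50, 20, 30, 60, 15, 10],
})

# Total export value per country, keeping the three largest
top = data_export.groupby('country')['value'].sum().nlargest(3)

fig, ax = plt.subplots(figsize=(10, 5))
ax.barh(top.index, top.values)
# Invert the y-axis so that the largest bar is drawn at the top
ax.invert_yaxis()
ax.set_xlabel("Value")
ax.set_ylabel("Country")
ax.set_title("Country-wise Export")
print(top.to_dict())
# In a script, call plt.show() here to display the figure
```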

Find the trend in the trade deficit for India.

The trade deficit is the amount by which the cost of a country’s imports exceeds the value of its exports. 

Let’s plot a line chart and compare the trend of total import and export values for each year from India. This will give us a fair idea about the trade deficit:

 
#Plotting a simple line plot
fig = plt.figure(figsize = (10, 5))
df[df['cat'] == 'I'].groupby(['year'])['value'].sum().plot(color='orange', label='Import')
df[df['cat'] == 'E'].groupby(['year'])['value'].sum().plot(color='purple', label='Export')
plt.ylabel('Total value')
plt.legend()
plt.show()
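
The deficit itself can also be computed directly as imports minus exports for each year. Here is a minimal sketch on made-up numbers. The column names follow the article, while the values and the use of pivot_table() are assumptions for illustration.

```python
import pandas as pd
import matplotlib.pyplot as plt

# Made-up records: 'E' = export, 'I' = import
df = pd.DataFrame({
    'year':  [2015, 2015, 2015, 2016, 2016, 2016],
    'cat':   ['E', 'E', 'I', 'E', 'I', 'I'],
    'value': [60, 40, 120, 110, 90, 60],
})

# One row per year, one column per category ('E' and 'I')
totals = df.pivot_table(index='year', columns='cat', values='value', aggfunc='sum')
# Trade deficit = total imports - total exports
totals['deficit'] = totals['I'] - totals['E']

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(totals.index, totals['I'], color='orange', label='Import')
ax.plot(totals.index, totals['E'], color='purple', label='Export')
ax.bar(totals.index, totals['deficit'], color='lightgray', label='Deficit')
ax.set_xlabel('Year')
ax.set_ylabel('Value')
ax.legend()
print(totals['deficit'].tolist())
# In a script, call plt.show() here to display the figure
```

For 2015 the exports sum to 100 and the imports to 120, giving a deficit of 20. For 2016 the figures are 110 and 150, giving a deficit of 40.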

Visualize the max export values to the UK for each year.

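To find the commodity with the highest export value to the UK in each year, select the row where the value is largest within that year. Calling groupby(['year']).max() on several columns at once would not do this: it takes the maximum of every column independently, so the Commodity column would simply show the alphabetically last name (such as ‘ZINC AND ARTICLES THEREOF’) rather than the commodity that produced the maximum value:

 
uk_exports = df[(df['country'] == 'U K') & (df['cat'] == 'E')].dropna(subset=['value'])
uk_exports.loc[uk_exports.groupby('year')['value'].idxmax(), ['year', 'value', 'Commodity']]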

Let’s plot a bar graph for the max export values to UK for each year:

 
#Plotting a vertical bar graph
fig = plt.figure(figsize = (10, 5))
df[(df['country'] == 'U K') & (df['cat'] == 'E')].groupby(['year'])['value'].max().plot(kind="bar")
plt.xlabel('Year')
plt.ylabel('Maximum value')
plt.show()

Compare the import/export values of expensive commodities across HS codes.

Let’s say that a trade is expensive if its value exceeds 1000. A box plot shows the median and spread of these values, grouped here by HSCode:

 
thresh = 1000
df2 = df[(df.value > thresh)]
#Plotting a box plot
df2.boxplot(column = 'value', by = 'HSCode', figsize=(12,7))

Analyze the import values of commodities imported from Canada.

Plot the total values of imports each year using a pie chart:

 
x = df[(df.country == 'CANADA') & (df.cat == 'I')].groupby('year')['value'].sum().reset_index()
#Plotting a pie chart
plt.pie(x['value'],labels=x['year']);
plt.axis('equal')
plt.tight_layout()
plt.show()

We can make the pie chart more readable by highlighting the maximum import value:

 
#Explode the slice with the highest imports
fig = plt.figure(figsize = (10, 5))
plt.pie(x['value'],labels=x['year'],explode=(0,0.15,0,0,0,0,0,0,0),startangle=90,autopct='%1.1f%%');
plt.axis('equal')
plt.title('Imports by Canada')
plt.tight_layout()
plt.show()
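
The explode tuple above is written out by hand, so it only works when there are exactly nine slices and the largest one is in the second position. A possible way to derive it from the data instead is sketched below, using a made-up table with the same year and value columns:

```python
import pandas as pd
import matplotlib.pyplot as plt

# Made-up yearly totals with the same column names as the table x above
x = pd.DataFrame({'year': [2015, 2016, 2017, 2018],
                  'value': [30, 80, 45, 25]})

# Offset of 0.15 for the row with the largest value, 0 for all other rows
largest = x['value'].idxmax()
explode = [0.15 if i == largest else 0 for i in x.index]

fig, ax = plt.subplots(figsize=(10, 5))
ax.pie(x['value'], labels=x['year'].astype(str), explode=explode,
       startangle=90, autopct='%1.1f%%')
ax.axis('equal')
print(explode)
# In a script, call plt.show() here to display the figure
```

Because the offsets are computed from the values, the same code keeps highlighting the right slice when the number of years or the data changes.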

So, you see how easy Matplotlib makes it to visualize and analyze data? Once you have found relevant patterns in your data, you can go ahead with model development using Machine Learning algorithms.

Endnotes

Matplotlib is one of the oldest Python data visualization libraries, and thanks to its wealth of features and ease of use it is still one of the most widely used ones. Matplotlib was first released back in 2003 and has been continuously updated since. I hope this article helped you understand the concept of Matplotlib in Machine Learning.

