Showing posts with label debugging. Show all posts

Saturday, October 28, 2017

Debugging Keras Networks


Last week a colleague and I were trying to figure out why his network would crash with a NaN (Not a Number) error some 20 or so epochs into training. Lately I have also become more interested in tuning neural networks, so this was a good opportunity for me to suggest fixes based on reasoning about the network. The network itself was built with Keras, like all the other networks our team has built from scratch so far, although we have adapted some third party networks written in Caffe and Tensorflow as well.

Now Keras is great for fast development because of its high level API. It results in very expressive code that reads like how you would actually visualize the network in your head or on a piece of paper. Also, because Keras automates away so many things and provides reasonable default values for many of its parameters, there are fewer things programmers can make mistakes about. For example, this awesome post on How to unit test machine learning code is based on Tensorflow, and while some of the cases mentioned are possible in Keras, they are much less likely.

However, while it is very easy to go from design to code in Keras, it is actually a little harder to work with, compared to say Tensorflow or Pytorch, when things go wrong and you have to figure out what. Fortunately, Keras does offer some tools and hooks that allow you to do this. In this post I talk about some of these that we (re-)discovered for ourselves last week. If you have favorites that I haven't included, please let me know in the comments.

The example I will use throughout this post is a simple fully connected network that I built to recognize MNIST images. The code to train and evaluate this network can be found here. The code to define and compile it is as follows:

model = Sequential()
model.add(Dense(512, activation="relu", input_shape=(784,)))
model.add(Dropout(0.2))
model.add(Dense(256, activation="relu"))
model.add(Dropout(0.2))
model.add(Dense(10, activation="softmax"))

model.compile(optimizer="adam", loss="categorical_crossentropy", 
              metrics=["accuracy"])

The first class of issues I have seen has to do with sizing the intermediate tensors in the network. Keras only asks that you provide the dimensions of the input tensor(s), and it figures out the rest of the tensor dimensions automatically. The flip side of this convenience is that programmers may not realize what the dimensions are, and may make design errors based on this lack of understanding. Keras provides a model.summary() function that prints the output shape and parameter count for each layer. I have found this very useful to get a better intuition about a network.

model.summary()

Layer (type)                 Output Shape              Param #   
=================================================================
dense_4 (Dense)              (None, 512)               401920    
_________________________________________________________________
dropout_3 (Dropout)          (None, 512)               0         
_________________________________________________________________
dense_5 (Dense)              (None, 256)               131328    
_________________________________________________________________
dropout_4 (Dropout)          (None, 256)               0         
_________________________________________________________________
dense_6 (Dense)              (None, 10)                2570      
=================================================================
Total params: 535,818
Trainable params: 535,818
Non-trainable params: 0
_________________________________________________________________
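
The Param # column can be checked by hand: a Dense layer with n inputs and m units holds n*m weights plus m biases, and Dropout layers add no parameters. Here is a minimal sketch of that arithmetic for the network above (the helper name dense_params is mine, not part of Keras):

```python
def dense_params(n_inputs, n_units):
    """Weight matrix (n_inputs x n_units) plus one bias per unit."""
    return n_inputs * n_units + n_units

# (inputs, units) for the three Dense layers defined above
layer_dims = [(784, 512), (512, 256), (256, 10)]
counts = [dense_params(n, m) for n, m in layer_dims]
total = sum(counts)
print(counts, total)
```

If the numbers you compute this way disagree with what model.summary() reports, that is usually a sign that some layer is not receiving the input shape you think it is.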

If you need more granular information about the intermediate tensors, take a look at the About Keras layers page. You can get input shapes as well using some code like this:

for layer in model.layers:
    print(layer.name, layer.input.shape, layer.output.shape)

dense_1 (?, 784) (?, 512)
dropout_1 (?, 512) (?, 512)
dense_2 (?, 512) (?, 256)
dropout_2 (?, 256) (?, 256)
dense_3 (?, 256) (?, 10)

Another built-in diagnostic tool that I have been ignoring a bit so far is Tensorboard. Tensorboard was originally developed as part of the Tensorflow ecosystem, and allows Tensorflow developers to log certain things into a Tensorboard log file, which can later be used to visualize these logs graphically. The Keras project provides a way to write to Tensorboard using its TensorBoard callback. I learned to extract loss and other metrics from the output of model.fit() and plot it with matplotlib before the TensorBoard callback was popular, and have continued to use the approach mostly due to inertia. But the TensorBoard callback provides not only these plots, but the weight distributions for all the weights, biases and gradients. In case of networks where Embeddings and Images are involved, Tensorboard provides visualizations for them as well.

To use the TensorBoard callback, you first instantiate it and then pass it in the callbacks list of the model.fit() call.

tensorboard = TensorBoard(log_dir=TENSORBOARD_LOGS_DIR, 
                          histogram_freq=1, 
                          batch_size=BATCH_SIZE, 
                          write_graph=True, 
                          write_grads=True, 
                          write_images=False, 
                          embeddings_freq=0, 
                          embeddings_layer_names=None, 
                          embeddings_metadata=None)
...
history = model.fit(Xtrain, Ytrain, batch_size=BATCH_SIZE, 
                    epochs=NUM_EPOCHS,
                    validation_split=0.1,
                    callbacks=[..., tensorboard, ...])

Here is the kind of visualization you can expect on Tensorboard. The best resources I have found on interpreting these visualizations are Dandelion Mané's talk at Tensorflow Developers Summit 2017 and the Tensorboard documentation on Histograms.



As nice as the Tensorboard callback is, it may not work for you all the time. For one thing, it appears that it doesn't work with fit_generator. You may also want to log values which are not meant to be logged with the Tensorboard callback. You can do that by writing your own callback in Keras.

Here is a callback that captures the L2 norm, mean and standard deviation of each weight tensor in the network at the end of each epoch and, at the end of training, dumps these values to the screen.

from keras import backend as K
from keras.callbacks import Callback
import numpy as np
            
def calc_stats(W):
    return np.linalg.norm(W, 2), np.mean(W), np.std(W)

class MyDebugWeights(Callback):
    
    def __init__(self):
        super(MyDebugWeights, self).__init__()
        self.weights = []
        self.tf_session = K.get_session()
            
    def on_epoch_end(self, epoch, logs=None):
        for layer in self.model.layers:
            name = layer.name
            for i, w in enumerate(layer.weights):
                w_value = w.eval(session=self.tf_session)
                w_norm, w_mean, w_std = calc_stats(np.reshape(w_value, -1))
                self.weights.append((epoch, "{:s}/W_{:d}".format(name, i), 
                                     w_norm, w_mean, w_std))
    
    def on_train_end(self, logs=None):
        for e, k, n, m, s in self.weights:
            print("{:3d} {:20s} {:7.3f} {:7.3f} {:7.3f}".format(e, k, n, m, s))

The on_epoch_end and on_train_end are basically event handlers which fire off when the epoch has ended and when training has ended respectively. The Callback interface defines 6 such events, for the beginning and end of batch, epoch and training. See the Keras callbacks documentation for a list and some more examples.

You could use the callback above to train for a small number of epochs and observe how these attributes of the weight tensors change. At some point, I would like to write these values to disk and then read them and chart them maybe using something like Pandas, but my Pandas-fu is not strong enough for that at this time. Here is the output after 2 epochs of training.

Train on 54000 samples, validate on 6000 samples
Epoch 1/2
54000/54000 [==============================] - 4s - loss: 0.2830 - acc: 0.9146 - val_loss: 0.0979 - val_acc: 0.9718
Epoch 2/2
54000/54000 [==============================] - 3s - loss: 0.1118 - acc: 0.9663 - val_loss: 0.0758 - val_acc: 0.9773
  0 dense_1/W_0           28.236  -0.002   0.045
  0 dense_1/W_1            0.283   0.003   0.012
  0 dense_2/W_0           20.631   0.002   0.057
  0 dense_2/W_1            0.205   0.008   0.010
  0 dense_3/W_0            4.962  -0.005   0.098
  0 dense_3/W_1            0.023  -0.001   0.007
  1 dense_1/W_0           30.455  -0.003   0.048
  1 dense_1/W_1            0.358   0.003   0.016
  1 dense_2/W_0           21.989   0.002   0.061
  1 dense_2/W_1            0.273   0.010   0.014
  1 dense_3/W_0            5.282  -0.008   0.104
  1 dense_3/W_1            0.040  -0.002   0.013
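
For the write-to-disk idea mentioned above, here is a rough sketch using only the standard library csv module instead of Pandas. The helper names save_stats and load_series and the file name weight_stats.csv are hypothetical; the records have the same (epoch, name, norm, mean, std) layout as the tuples collected in MyDebugWeights.weights.

```python
import csv
import os
import tempfile

FIELDS = ["epoch", "name", "norm", "mean", "std"]

def save_stats(records, path):
    """Write (epoch, name, norm, mean, std) tuples to a CSV file with a header."""
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(FIELDS)
        writer.writerows(records)

def load_series(path, name, field="norm"):
    """Return (epoch, value) pairs for one weight tensor, ordered by epoch."""
    with open(path, newline="") as f:
        rows = [r for r in csv.DictReader(f) if r["name"] == name]
    return sorted((int(r["epoch"]), float(r[field])) for r in rows)

# Two rows copied from the output above
records = [
    (0, "dense_1/W_0", 28.236, -0.002, 0.045),
    (1, "dense_1/W_0", 30.455, -0.003, 0.048),
]
path = os.path.join(tempfile.mkdtemp(), "weight_stats.csv")
save_stats(records, path)
series = load_series(path, "dense_1/W_0")
print(series)
```

The resulting list of (epoch, value) pairs can be handed straight to a plotting library, and the CSV file itself can also be read with pandas.read_csv if you prefer to work with a DataFrame.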

Another thing we can do is to look at the attributes of the outputs at each layer. I initially tried to build this as another callback, but ran into some problems, then decided on this standalone implementation which can be called after every few epochs of training to see if anything has changed. This is adapted from the Keras FAQ.

def get_outputs(inputs, model):
    layer_01_fn = K.function([model.layers[0].input, K.learning_phase()], 
                             [model.layers[1].output]) 
    layer_23_fn = K.function([model.layers[2].input, K.learning_phase()],
                             [model.layers[3].output])
    layer_44_fn = K.function([model.layers[4].input, K.learning_phase()],
                             [model.layers[4].output])
    layer_1_out = layer_01_fn([inputs, 1])[0]
    layer_3_out = layer_23_fn([layer_1_out, 1])[0]
    layer_4_out = layer_44_fn([layer_3_out, 1])[0]
    return layer_1_out, layer_3_out, layer_4_out

out_1, out_3, out_4 = get_outputs(Xtest[0:10], model)
print("out_1", calc_stats(out_1))
print("out_3", calc_stats(out_3))
print("out_4", calc_stats(out_4))

I suspect we can make this more generic by looking up the model.layers data structure, but since it is kind of hard to forecast every kind of model you will build and because you will be doing this once per model, a quick and dirty implementation like the above may be preferable to something nicer. As before, we can rerun this every couple of epochs and get back the L2 norm, mean and standard deviation of the output tensors at each layer, as shown below.

out_1 (15.320195, 0.15846619, 0.36553052)
out_3 (31.983685, 0.52617866, 0.82984859)
out_4 (1.4138139, 0.1, 0.29160777)
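
One caveat when reading the first number in each tuple: np.linalg.norm(W, 2) returns the ordinary L2 norm only when W is one-dimensional. For a 2-D array it returns the largest singular value (the spectral norm). The weights callback flattens each tensor before calling calc_stats, but get_outputs passes 2-D arrays, so the two sets of "norms" are not directly comparable. The sketch below illustrates the difference on a toy matrix, and also works out the flattened L2 norm implied by the mean and standard deviation printed for out_4 (assuming 10 samples x 10 classes = 100 values), which comes to about 3.08 rather than the reported 1.41.

```python
import numpy as np

W = np.array([[3.0, 0.0],
              [0.0, 4.0]])

spectral = np.linalg.norm(W, 2)              # 2-D input: largest singular value
l2_flat = np.linalg.norm(W.reshape(-1), 2)   # 1-D input: sqrt of the sum of squares
frobenius = np.linalg.norm(W)                # default for 2-D input, equals l2_flat
print(spectral, l2_flat, frobenius)

# For n values, sum(x**2) == n * (std**2 + mean**2) when std is the population
# standard deviation (np.std's default), so the flattened L2 norm can be
# recovered from the printed mean and std of out_4.
n, mean, std = 100, 0.1, 0.29160777
implied_l2 = np.sqrt(n * (std ** 2 + mean ** 2))
print(implied_l2)
```

If you want the same quantity everywhere, flattening the array before calling calc_stats (as the callback does) is one simple option.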

Finally, we also wanted to figure out what the gradients looked like. The code for this is adapted heavily from Edward Banner's comment in Keras Issue 2226. Like the code for visualizing the outputs, this code also needs to be run after training for a few epochs, and its results compared with the previous values of L2 norm, mean and standard deviation for the gradients at different layers in the network.

def get_gradients(inputs, labels, model):
    opt = model.optimizer
    loss = model.total_loss
    weights = model.weights
    grads = opt.get_gradients(loss, weights)
    grad_fn = K.function(inputs=[model.inputs[0], 
                                 model.sample_weights[0],
                                 model.targets[0],
                                 K.learning_phase()], 
                         outputs=grads)
    grad_values = grad_fn([inputs, np.ones(len(inputs)), labels, 1])
    return grad_values

gradients = get_gradients(Xtest[0:10], Ytest[0:10], model)
for i in range(len(gradients)):
    print("grad_{:d}".format(i), calc_stats(gradients[i]))

As before, the output below shows the L2 norm, mean and standard deviation of the gradients for each weight tensor. As with the output tensors, we train the network for 2 epochs, then run this block of code. As you can guess, this sort of debugging works really well with an interactive development environment such as Jupyter Notebooks.

grad_0 (1.7725379, 1.1711028e-05, 0.0028093776)
grad_1 (0.17403033, 3.4195516e-05, 0.0076910509)
grad_2 (1.2508092, -7.3888972e-05, 0.003460743)
grad_3 (0.12154519, -0.00047613602, 0.0075816377)
grad_4 (1.5319482, 4.8748915e-11, 0.030318365)
grad_5 (0.10286356, -4.6566129e-11, 0.032528315)
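
Since the problem that started all this was a NaN partway through training, it can be handy to scan the arrays returned by get_gradients (or get_outputs) for non-finite values instead of eyeballing the statistics. This is only a sketch; the helper name find_non_finite is hypothetical and the arrays below are stand-ins rather than real gradients.

```python
import numpy as np

def find_non_finite(arrays):
    """Return the indices of arrays that contain at least one NaN or infinity."""
    return [i for i, a in enumerate(arrays) if not np.all(np.isfinite(a))]

# Stand-in arrays; in practice pass the list returned by get_gradients()
fake_grads = [np.zeros((4, 3)),
              np.array([0.1, np.nan]),
              np.array([np.inf, 1.0])]
bad = find_non_finite(fake_grads)
print(bad)
```

The indices line up with the grad_0, grad_1, ... labels printed above, so a non-empty result tells you which weight tensor to look at first.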

That's all I had for today. The example network I have used here is quite simple, but these same ideas and tools can be used to debug more complex networks as well. These tools were built based on discussions between my colleague and me last week, and the code is available here. I am sure many of you have your own favorite tools and tricks. If so, and you are okay with sharing, please let us know in the comments.

Saturday, May 01, 2010

Debugging XML with Apache XMLRPC

It's been quite insane the last couple of months at work, which is why I haven't been posting as frequently as I would like. I usually do the work I write about on my commute to and from work, and I've either been too mentally exhausted to do anything, or too busy debugging work related problems in my head. To those who have been kind enough to comment, I apologize for not getting back sooner, but I hope you understand - I will get to them as soon as I can.

As you know, I have been trying to interface a Java based publishing system with the Drupal CMS - the interface is over XMLRPC. I have a custom module which traps publish/unpublish events for various content types, and sends over a map of name value pairs for the Java publishing system to persist, where it is used by the web front end. On the Java side, I use Apache XMLRPC in server mode. There are also a few cases where I call Drupal's XMLRPC service using an Apache XMLRPC client.

One thing that struck me early on is the opacity of the Apache XMLRPC library (I find Drupal almost equally opaque, but that is probably because of a combination of my relative inexperience with Drupal and the dynamic nature of PHP). I mean, I am using a library for generating and parsing XMLRPC because I am either lazy or smart (depending on your point of view), not because reading (or writing) XML makes my head hurt. In fact, because all my transactions involve an (almost) black box (Drupal) at one end or the other, being able to see the XML request and response can help me develop and debug the other end that much faster.

I searched the web for solutions to this problem, but I could not find what I was looking for - namely, some sort of switch to turn XMLRPC logging on and off in Apache XMLRPC. I did find some advice to proxy the request through a logging tool such as netcat, which leads me to believe that the feature I am looking for does not exist, and that the Apache XMLRPC team feels that such a feature is not important/useful enough to implement. I could be wrong, though - I would appreciate corrections and pointers.

I've been getting by so far with a dummy handler on the server which just spits out the request XML in the server logs - obviously the request would actually "fail" because the handler cannot respond, because by the time the handlers have a chance to get at the request, it's already been consumed. I've been meaning to do something nicer and more elegant, but just didn't have the time.

A few days ago, things came to a head when a really simple XMLRPC request to Drupal (user.login) resulted in the following exception from my client. Notice how completely useless it is.

    [junit] Testcase: testLogin took 0.499 sec
    [junit]     Caused an ERROR
    [junit] Failed to parse server's response: The markup in the document following the root element must be well-formed.
    [junit] org.apache.xmlrpc.client.XmlRpcClientException: Failed to parse server's response: The markup in the document following the root element must be well-formed.
    [junit] ...

That's when I decided that I really needed to stop wasting cycles trying to figure out these kinds of issues, and spend some time instrumenting the code to really see what's going on with the XML. My approach involves extending Apache XMLRPC to do this. I describe the results of my efforts in this post. I've been using this for about 3 days now and find it incredibly useful. Hopefully you will too.

Client Side

A typical Java XMLRPC client in my codebase looks like this. It's basically copied straight off the Apache XMLRPC Documentation.

    XmlRpcClientConfigImpl config = new XmlRpcClientConfigImpl();
    config.setServerURL(new URL("http://localhost/services/xmlrpc"));
    XmlRpcClient client = new XmlRpcClient();
    client.setTransportFactory(new XmlRpcCommonsTransportFactory(client));
    client.setConfig(config);
    Map<String,Object> data = new HashMap<String,Object>();
    // populate the map...
    Object ret = client.execute("methodName", new Object[] {data});
    // check the return value...

In order to get the actual XML request and response, I subclass the XmlRpcCommonsTransportFactory, and in my factory, return a subclass of XmlRpcCommonsTransport, which logs the request and response. To use this, all we need to do is to change this line in the above code:

    client.setTransportFactory(new CustomXmlRpcCommonsTransportFactory(client));

In fact, because the logging is triggered by the setting in the log4j.properties file, you can just switch out the XmlRpcCommonsTransportFactory with my custom version - if your logging is turned off for this class, then it behaves exactly like the original factory. Here's the code - the actual transport is modeled as an inner class within the factory, since it is only ever called from the factory.

// Source: src/main/java/com/mycompany/myapp/xmlrpc/CustomXmlRpcCommonsTransportFactory.java
package com.mycompany.myapp.xmlrpc;

import java.io.ByteArrayInputStream;
import java.io.InputStream;

import org.apache.xmlrpc.XmlRpcException;
import org.apache.xmlrpc.client.XmlRpcClient;
import org.apache.xmlrpc.client.XmlRpcCommonsTransport;
import org.apache.xmlrpc.client.XmlRpcCommonsTransportFactory;
import org.apache.xmlrpc.client.XmlRpcTransport;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class CustomXmlRpcCommonsTransportFactory extends
    XmlRpcCommonsTransportFactory {

  private final Logger logger = LoggerFactory.getLogger(getClass());
  
  public CustomXmlRpcCommonsTransportFactory(XmlRpcClient pClient) {
    super(pClient);
  }
  
  @Override
  public XmlRpcTransport getTransport() {
    return new LoggingTransport(this);
  }
  
  private class LoggingTransport extends XmlRpcCommonsTransport {

    public LoggingTransport(CustomXmlRpcCommonsTransportFactory pFactory) {
      super(pFactory);
    }

    /**
     * Logs the request content in addition to the actual work.
     */
    @Override
    protected void writeRequest(final ReqWriter pWriter) throws XmlRpcException {
      super.writeRequest(pWriter);
      if (logger.isDebugEnabled()) {
        CustomLoggingUtils.logRequest(logger, method.getRequestEntity());
      }
    }

    /**
     * Logs the response from the server, and returns the contents of
     * the response as a ByteArrayInputStream.
     */
    @Override
    protected InputStream getInputStream() throws XmlRpcException {
      InputStream istream = super.getInputStream();
      if (logger.isDebugEnabled()) {
        return new ByteArrayInputStream(
          CustomLoggingUtils.logResponse(logger, istream).getBytes());
      } else {
        return istream;
      }
    }
  }
}

Since I used a similar approach to log XML requests and responses on the server side as well, I decided to move my logging code (which includes prettifying the XML for readability) into a common utilities class. Here is the Logging utilities class. Included in the class is a method to prettify the XML request and response - Drupal sends out a nicely formatted XML chunk, but Apache XMLRPC sends it out in one long line (for performance I think) - I found a nice way to do this here using just the standard Java libraries.

// Source: src/main/java/com/mycompany/myapp/xmlrpc/CustomLoggingUtils.java
package com.mycompany.myapp.xmlrpc;

import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.StringReader;
import java.io.StringWriter;

import javax.xml.transform.OutputKeys;
import javax.xml.transform.Transformer;
import javax.xml.transform.TransformerFactory;
import javax.xml.transform.stream.StreamResult;
import javax.xml.transform.stream.StreamSource;

import org.apache.commons.httpclient.methods.RequestEntity;
import org.apache.commons.io.IOUtils;
import org.apache.xmlrpc.XmlRpcException;
import org.slf4j.Logger;

public class CustomLoggingUtils {

  public static void logRequest(Logger logger, 
      RequestEntity requestEntity) throws XmlRpcException {
    ByteArrayOutputStream bos = null;
    try {
      logger.debug("---- Request ----");
      bos = new ByteArrayOutputStream();
      requestEntity.writeRequest(bos);
      logger.debug(toPrettyXml(logger, bos.toString()));
    } catch (IOException e) {
      throw new XmlRpcException(e.getMessage(), e);
    } finally {
      IOUtils.closeQuietly(bos);
    }
  }

  public static void logRequest(Logger logger, String content) {
    logger.debug("---- Request ----");
    logger.debug(toPrettyXml(logger, content));
  }

  public static String logResponse(Logger logger, InputStream istream) 
      throws XmlRpcException {
    BufferedReader reader = null;
    try {
      reader = new BufferedReader(new InputStreamReader(istream));
      String line = null;
      StringBuilder respBuf = new StringBuilder();
      while ((line = reader.readLine()) != null) {
        respBuf.append(line);
      }
      String response = respBuf.toString();
      logger.debug("---- Response ----");
      logger.debug(toPrettyXml(logger, respBuf.toString()));
      return response;
    } catch (IOException e) {
      throw new XmlRpcException(e.getMessage(), e);
    } finally {
      IOUtils.closeQuietly(reader);
    }
  }

  public static void logResponse(Logger logger, String content) {
    logger.debug("---- Response ----");
    logger.debug(toPrettyXml(logger, content));
  }

  private static String toPrettyXml(Logger logger, String xml) {
    try {
      Transformer transformer = 
        TransformerFactory.newInstance().newTransformer();
      transformer.setOutputProperty(OutputKeys.INDENT, "yes");
      transformer.setOutputProperty(
        "{http://xml.apache.org/xslt}indent-amount", "2");
      StreamResult result = new StreamResult(new StringWriter());
      StreamSource source = new StreamSource(new StringReader(xml));
      transformer.transform(source, result);
      return result.getWriter().toString();
    } catch (Exception e) {
      logger.warn("Can't parse XML");
      return xml;
    }
  }
}

Server Side

On the server side, I use a similar approach (i.e., subclassing) as on the client side. In my code, I use the XmlRpcServletServer embedded inside a Spring controller, like this:

  @PostConstruct
  protected void init() throws Exception {
    XmlRpcServerConfigImpl config = new XmlRpcServerConfigImpl();
    config.setBasicEncoding(encoding);
    config.setEnabledForExceptions(enabledForExceptions);
    config.setEnabledForExtensions(enabledForExtensions);

    service = new XmlRpcServletServer();
    service.setConfig(config);
    ...
  }

  @RequestMapping(value="/someMethod", method=RequestMethod.POST)
  public void publish(HttpServletRequest request, 
      HttpServletResponse response) throws Exception {
    service.execute(request, response);
  }

I extend the XmlRpcServletServer and override its execute() method to look at the log4j setting and optionally log its request and response XML to the server logs. This is done inline with the code: a request comes in, is logged to the server logs, and is then acted upon by the execute method, which generates the response. Before being sent back to the client, the response is also logged to the server logs.

This is slightly more involved. It involves wrapping the request and response parameters in request and response wrapper objects, which in turn return subclasses of ServletInputStream and ServletOutputStream, where the actual magic happens.

The custom ServletInputStream copies the contents of the real InputStream that it wraps into a ByteArrayInputStream on construction and logs the request contents. In the overridden read() method, it returns bytes from the ByteArrayInputStream instead of the real InputStream. The custom ServletOutputStream wraps the real OutputStream and has an overridden write() method which copies the bytes into a StringBuilder buffer as they come in. On close(), the response is logged from the StringBuilder and the real OutputStream is closed. The idea is derived from this page. Here is the code.

// Source: src/main/java/com/mycompany/myapp/xmlrpc/CustomXmlRpcServletServer.java
package com.mycompany.myapp.xmlrpc;

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;

import javax.servlet.ServletException;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;

import org.apache.xmlrpc.webserver.XmlRpcServletServer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class CustomXmlRpcServletServer extends XmlRpcServletServer {

  private final Logger logger = LoggerFactory.getLogger(getClass());
  
  @Override
  public void execute(HttpServletRequest request, 
      HttpServletResponse response) throws ServletException, IOException {
    if (logger.isDebugEnabled()) {
      super.execute(new LoggingRequestWrapper(request), 
        new LoggingResponseWrapper(response));
    } else {
      super.execute(request, response);
    }
  }

  private class LoggingRequestWrapper extends HttpServletRequestWrapper {

    private HttpServletRequest originalRequest;
    private LoggingServletInputStream loggingInputStream = null;
    
    public LoggingRequestWrapper(HttpServletRequest request) {
      super(request);
      this.originalRequest = request;
    }
    
    @Override
    public ServletInputStream getInputStream() throws IOException {
      loggingInputStream = new LoggingServletInputStream(
        originalRequest.getInputStream());
      return loggingInputStream;
    }
    
    @Override
    public BufferedReader getReader() throws IOException {
      return new BufferedReader(new InputStreamReader(getInputStream()));
    }
  }
  
  private class LoggingServletInputStream extends ServletInputStream {

    private ServletInputStream istream;
    private ByteArrayInputStream standinInputStream;
    
    public LoggingServletInputStream(ServletInputStream istream) 
        throws IOException {
      this.istream = istream;
      ByteArrayOutputStream bos = new ByteArrayOutputStream();
      byte[] buf = new byte[4096];
      int n = 0;
      while (true) {
        n = istream.read(buf);
        if (n == -1) {
          break;
        }
        bos.write(buf, 0, n);
      }
      this.standinInputStream = new ByteArrayInputStream(bos.toByteArray());
      CustomLoggingUtils.logRequest(logger, new String(bos.toByteArray()));
    }
    
    @Override
    public int read() throws IOException {
      int c = standinInputStream.read();
      return c;
    }
  }
  
  private class LoggingResponseWrapper extends HttpServletResponseWrapper {

    private HttpServletResponse originalResponse;
    private LoggingServletOutputStream loggingOutputStream = null;
    
    public LoggingResponseWrapper(HttpServletResponse response) {
      super(response);
      this.originalResponse = response;
    }
    
    @Override
    public ServletOutputStream getOutputStream() throws IOException {
      this.loggingOutputStream = new LoggingServletOutputStream(
        originalResponse.getOutputStream());
      return loggingOutputStream;
    }
    
    @Override
    public PrintWriter getWriter() throws IOException {
      return new PrintWriter(new OutputStreamWriter(getOutputStream()));
    }
  }

  private class LoggingServletOutputStream extends ServletOutputStream {

    private ServletOutputStream ostream;
    private StringBuilder buf;
    
    public LoggingServletOutputStream(ServletOutputStream ostream) {
      this.ostream = ostream;
      this.buf = new StringBuilder();
    }
    
    @Override
    public void write(int b) throws IOException {
      buf.append((char) b);
      ostream.write(b);
    }

    @Override
    public void close() throws IOException {
      if (buf.length() > 0) {
        CustomLoggingUtils.logResponse(logger, buf.toString());
      }
      ostream.close();
    }
  }
}

To turn logging on and off, you will need to tweak your log4j.properties file. Basically, if the logging for either of these is set to DEBUG, then it will log, otherwise it will not.

Oh, and by the way, remember the malformed XML exception that started this off in the first place? Turns out that I had the Devel module turned on in my Drupal installation, so the trace of the SQL statements that were being executed was also being echoed back after the closing </methodResponse> tag in the response.
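
This failure mode is easy to reproduce outside of Java. Here is a minimal sketch in Python (chosen purely for brevity, the code in this post is Java) using the standard library's ElementTree parser. The response payload and the trailing debug text are made up for illustration; they are not what Drupal actually sent.

```python
import xml.etree.ElementTree as ET

clean = ("<methodResponse><params><param><value>"
         "<string>ok</string>"
         "</value></param></params></methodResponse>")
# Hypothetical debug output appended after the root element has been closed
polluted = clean + "<pre>SELECT * FROM users</pre>"

def parses(xml_text):
    """Return True if xml_text is a well-formed XML document."""
    try:
        ET.fromstring(xml_text)
        return True
    except ET.ParseError:
        return False

print(parses(clean), parses(polluted))
```

An XML document can have only one root element, so anything a misbehaving module writes after it makes the whole response unparseable, even though the XMLRPC payload itself is fine.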

Saturday, February 09, 2008

Debugging and Profiling with Eclipse

This post contains some settings I use for remote debugging web applications using the Jetty and Tomcat containers, and profiling web applications deployed on a remote Tomcat server, using the Eclipse IDE. By remote I mean connecting over a socket; the container can (and does in my case, unless I am connecting from home) listen on a port on the local host. The stuff here is hardly original; it has been gleaned from various web pages and blogs, which I reference in the appropriate places. If you use (or are considering using) Eclipse and want to know how to do remote debugging and profiling, this information may be of some use to you.

Debugging

I have been using the Eclipse IDE (with the MyEclipse extension) for about 3 years now. Most of the time, when debugging, I just use logger.debug() calls within the code to see what's going on. I do know how to debug using the Eclipse Debug perspective, but I guess it's just a habit I developed, and old habits die hard. I don't even use Eclipse's CVS perspective anymore, based on some bad experiences at a previous company where I tried but ended up inadvertently removing from CVS code that I removed locally in my IDE (it was incorrect usage on my part). However, lately, I am starting to find debugging very useful, mainly because of the long stop-deploy-start cycle for our main web application.

Unlike a lot of IDE users, I like to run my web container from the command line rather than from the IDE. This is because of two reasons. First, I think the primary goal should be being able to build a WAR file using Ant (or Maven) and being able to deploy to a container. A lot of IDEs make you go through various hoops to make the webapp "compliant", where the definition of what constitutes compliance can vary from IDE to IDE. As an Eclipse user, I have been a minority at my last two jobs, where the majority of Java developers use IDEA, so it usually turns out that I have to make Eclipse comply with what IDEA thinks is a webapp. Second, having to stop and restart the app within a container running within your IDE involves using your mouse (or in case of a laptop, your touchpad), which is way less convenient than the command line with command-history enabled.

We run and develop our main web application using Tomcat. I have been building Maven apps for quite a while now, and I tend to use the Maven-Jetty plugin because it's so much more convenient. For Maven webapps, I tend to do most of my development using Jetty, then deploy to the Tomcat server. The upshot is that I need to be able to debug using remote Tomcat and Jetty instances.

Remote Debugging with Tomcat

The information here is from the Tomcat FAQ Wiki. Basically, you add this in to the $CATALINA_HOME/bin/setenv.sh file. My CATALINA_HOME is at /opt/apache-tomcat-5.5.25. If you already have a JAVA_OPTS defined for application-specific stuff, just add the stuff below to your JAVA_OPTS.

# /opt/apache-tomcat-5.5.25/bin/setenv.sh
export JAVA_OPTS="-Xdebug \
  -Xrunjdwp:transport=dt_socket,address=8787,server=y,suspend=n"

The address=8787 enables a debug listener on Tomcat that Eclipse can connect to in order to get debug information. On the Eclipse side, open the Debug Launch Configuration Dialog by clicking "Run > Open Debug Dialog". On the left pane of the dialog, find "Remote Java Application", select and right-click (or click on the New icon on the top). This will open up a Dialog for setting parameters for a Debug Launch configuration. Here are my values:

| Tab name | Property name | Property value | Description |
|---|---|---|---|
| - | Name | Tomcat (Pluto:8080) | Can be any name you want to give it. Mine says what and where |
| Connect | Project | hl-www | This is your project name |
| Connect | Connection Type | Standard - Socket Attach | Connect over a socket |
| Connect | Connection Properties : Host | pluto.healthline.com | DNS name of the host, could be an IP address (I think) |
| Connect | Connection Properties : Port | 8787 | Same port as specified in address above |
| Connect | Allow termination of remote VM | No | This is really your choice, I just don't want it. |
| Source | Source Lookup Path | Select your project | This is so you can see the sources as you debug |
| Source | Source Lookup Path | Select any other source jars you have | This is so you can see the sources as you debug |
| Common | Display in Favorites Menu | Yes | This adds the config as a bookmark under the debug icon. |

Deploy your app to the Tomcat container and restart Tomcat. In Eclipse, switch to the Debug perspective and set a breakpoint in your code (say a controller you want to call). In Eclipse, [Ctrl]-[Shift]-B sets (or unsets) a breakpoint on the current line. Open up a browser and point to the page you want to debug. Bringing the page up will activate the debugger in Eclipse and you will see the code where you set the breakpoint being highlighted, with the top right corner containing the variables to be inspected. You can use the [F5] through [F8] keys to step into, step over, step out of (return from) the current method, and resume. You probably know how to take it from here.
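If Eclipse refuses to attach, it helps to know whether the debug listener is up at all. Here is a minimal Python sketch of such a check, assuming the standard JDWP handshake (the debugger sends the 14 ASCII bytes "JDWP-Handshake" and the JVM echoes them back). The function name jdwp_listening is my own, not part of any tool, and the host and port in the usage note are the ones from my setup above.

```python
import socket

# The JDWP handshake: the debugger sends this string and the JVM echoes it back.
HANDSHAKE = b"JDWP-Handshake"

def jdwp_listening(host, port, timeout=3.0):
    """Return True if something on host:port answers the JDWP handshake."""
    try:
        with socket.create_connection((host, port), timeout=timeout) as sock:
            sock.sendall(HANDSHAKE)
            reply = b""
            # Keep reading until we have the full echo or the peer hangs up.
            while len(reply) < len(HANDSHAKE):
                chunk = sock.recv(len(HANDSHAKE) - len(reply))
                if not chunk:
                    break
                reply += chunk
            return reply == HANDSHAKE
    except OSError:
        # Connection refused, timed out, host unknown, and so on.
        return False
```

Calling jdwp_listening("pluto.healthline.com", 8787) should return True once Tomcat has started with the JAVA_OPTS shown earlier. I would run it before attaching Eclipse rather than during a debug session, since as far as I know the listener serves one debugger at a time.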

Remote Debugging with the Maven-Jetty plugin

Information for this comes from Dan Allen's blog post Remote Debugging with Jetty. Unlike Tomcat, this time you have to set the debugging parameters within MAVEN_OPTS since Maven runs its classworlds Launcher instead of Java. The MAVEN_OPTS need to be set in your configuration (either in your ~/.bash_profile or in a shell script that calls the mvn jetty6:run command). As before, if you already have other stuff in your MAVEN_OPTS, the stuff below needs to go after that.

export MAVEN_OPTS="-Xdebug -Xnoagent -Djava.compiler=NONE \
  -Xrunjdwp:transport=dt_socket,address=8781,server=y,suspend=n"
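The Tomcat and Jetty option strings share the same -Xrunjdwp sub-options and differ only in the port and a couple of extra flags. As an illustration only, here is a small Python sketch that assembles the sub-options and notes what each one does; the function name jdwp_options is hypothetical.

```python
def jdwp_options(port, suspend=False):
    """Build the -Xdebug/-Xrunjdwp flags used in the snippets above."""
    sub_options = [
        "transport=dt_socket",  # talk to the debugger over a TCP socket
        "address=%d" % port,    # port the JVM listens on for the debugger
        "server=y",             # the JVM waits for a debugger to attach to it
        # suspend=n lets the JVM start normally; suspend=y would block
        # startup until a debugger connects
        "suspend=%s" % ("y" if suspend else "n"),
    ]
    return "-Xdebug -Xrunjdwp:" + ",".join(sub_options)

print(jdwp_options(8781))
```

With port 8781 this yields the -Xdebug and -Xrunjdwp part of the MAVEN_OPTS value above; the -Xnoagent and -Djava.compiler=NONE flags are separate JVM options and are not covered here.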

You also need to disable the Jetty maxIdleTime interval by setting it to 0. This is done in the pom.xml file like so:

<project ...>
  ...
  <build>
    <plugins>
      <plugin>
        <groupId>org.mortbay.jetty</groupId>
        <artifactId>maven-jetty6-plugin</artifactId>
        <configuration>
          <scanIntervalSeconds>10</scanIntervalSeconds>
          <connectors>
            <connector implementation="org.mortbay.jetty.nio.BlockingChannelConnector">
              <port>8081</port>
              <maxIdleTime>0</maxIdleTime>
            </connector>
          </connectors>
        </configuration>
        <dependencies>
          <dependency>
            <groupId>org.apache.geronimo.specs</groupId>
            <artifactId>geronimo-j2ee_1.4_spec</artifactId>
            <version>1.0</version>
            <scope>provided</scope>
          </dependency>
        </dependencies>
      </plugin>
    </plugins>
  </build>
</project>

On the Eclipse side, the setup is identical to the Tomcat setup described above. Simply change the name (mine is called Jetty (Pluto:8081)) and the port number of the listener to what you set it to in MAVEN_OPTS (mine is 8781).

Profiling

Recently, I needed to profile a web application I wrote. It was taking 4-8 seconds to serve a single page on a production class machine, compared to an expectation of about 40-80 milliseconds. Response times on my much less powerful development box, while not 40-80ms, were tolerable. My initial reaction was to put StopWatch calls within the handleRequest() method of the Controller, timing the blocks which I thought could do with improvement. That detected some places where it was spending more time than I thought it should, so I fixed those, but the pages were still dog slow on production. Moreover, it seemed that response times were degrading under load, and load on the database machines was spiking so as to make them almost unusable. What I needed was a profiler, but I did not know how to set one up, much less know how to run it and interpret the results.

However, good things sometimes happen to bad programmers, and our local performance guru was kind enough to set up a profiling instance on his Netbeans IDE (he is an IDEA user, but he uses Netbeans for its awesome profiling tool) and run a profile for me. It did identify several more hotspots in the code that could be optimized, and I fixed them. The performance did improve somewhat as a result, but we were still seeing spikes on the database machines.

The problem turned out to be contention for the same database resource with another web application, which I figured out by just thinking through it and looking through the code. However, the profiler output helped me weed out the unnecessary stuff quickly. So although the best way to find performance problems is still, in my opinion, just trolling through code coupled with an understanding of the program flow, a profiler makes the process much faster, because it has already told you what you are not looking for.

While I now know (thanks to the same guy who helped me out with the performance numbers before) how to do profiling with the Netbeans IDE, I wanted to do this from within Eclipse using the TPTP plugin, so what follows is my setup for doing that.

Remote Profiling Tomcat apps

Information for this comes from this profiling java blog post, which has a link to an Eclipse-TPTP setup Howto on Windows XP, which I adapted for my use. TPTP needs a client component to be installed in the Eclipse IDE (the TPTP plugin), and an agent component RAServer which mediates between the performance data from the Tomcat server and the Eclipse TPTP client. Huge amounts of profiling data are transferred as XML documents, so using this from a remote (not localhost) client is very slow. Three things need to be set up to use TPTP to profile remote apps under Eclipse.

First, we need to download the TPTP plugin. If you are using a recent version of Eclipse (I am using 3.3.1.1) then you can get the plugin from the Europa Discovery Site. Simply click on "Help > Software Updates > Find and Install > Search for new features to install", then select the Performance and Monitoring features and click on "Select Required". This will download the TPTP plugin to your IDE. Restart your IDE to see the Profile icon on the toolbar, and "Run > Profile..." entries in your menu. The complete procedure is explained in detail in the Installing TPTP using Update Manager page.

Second, we need to install the agent component. This is available as a separate download for the particular architecture and operating system from the TPTP home page (scroll down to Agent Controller). Here is a link to the one I used.

Setting this up was easy, but not totally straightforward. The first step is to unzip the download into /opt/tptpdc-4.1.0, then set up the following environment variables in your ~/.bash_profile and source it. Here is the snippet from my ~/.bash_profile file.

# TPTP settings
export RASERVER_HOME=/opt/tptpdc-4.1.0
export PATH=$RASERVER_HOME/bin:$PATH
export LD_LIBRARY_PATH=$RASERVER_HOME/lib:$LD_LIBRARY_PATH

We then need to navigate to $RASERVER_HOME/bin, then run SetConfig.sh (the very first time only) to set up the XML file for RAServer to work. Then from the same directory, we need to start the server using RAStart.sh (the corresponding stop script is RAStop.sh in the same directory). However, when I ran the RAStart.sh script, I discovered that there were missing libraries on my Fedora Core 7 system. To fix that, I had to download the libstdc++ compatibility RPM from the RPMFind page and install it using the following command:

$ sudo rpm -ivh compat-libstdc++-296-2.96-138.i386.rpm

Finally, we need to set up the JAVA_OPTS environment variable in the $CATALINA_HOME/bin/setenv.sh file, like so. Also, since we are starting Tomcat with the profiling instrumentation enabled, I found that it would complain about missing libraries, which went away after I added the RASERVER_HOME paths to PATH and LD_LIBRARY_PATH in the setenv.sh file.

# /opt/apache-tomcat-5.5.25/bin/setenv.sh
export RASERVER_HOME=/opt/tptpdc-4.1.0
export PATH=$RASERVER_HOME/bin:$PATH
export LD_LIBRARY_PATH=$RASERVER_HOME/lib:$LD_LIBRARY_PATH
export JAVA_OPTS="-XrunpiAgent:server=enabled"

To start using profiling, I deployed the web application to Tomcat, started RAServer, then started Tomcat.

On the Eclipse side, I built a Profiling Launch configuration by clicking "Run > Profile", then right-clicking New on "Attach to Agent" on the left pane of the resulting dialog. Here are the settings for my IDE.

| Tab name | Property name | Property value | Description |
|---|---|---|---|
| - | Name | WWW (Pluto:8080) | Can be anything. Mine says what and where. |
| Hosts | Default Hosts | Added pluto.healthline.com:10002 | localhost:10002 was already there and could not remove it. Adding the new one and selecting it makes that the current host. |
| Agents | Available Agents | Click on Refresh to get the standard Agent exposed by RAServer and select it. | localhost:10002 was already there and could not remove it. Adding the new one and selecting it makes that the current host. |
| Destination | Profiling Project | I just chose the same project name I was monitoring. | - |
| Destination | Monitor | Choose Default Monitor (the default) | - |
| Common | Display in Favorites Menu | Yes | Makes the configuration appear when the Profile icon is clicked. |

Once this is done, switch to the profiling perspective. If the agent has been discovered, Eclipse will attach to it and start collecting statistics. Since a web app's job is to serve pages, what I do is to aim a URL generating script at the application. Here is an example of a Python script that reads a list of URLs from a text file and hits the app with the URLs.

#!/usr/bin/env python3
# Simple harness to run the URLs from the systemtesturls.txt manually
import sys
import http.client
import time

def usage():
  print("Usage: " + sys.argv[0] + " www.myhost.com:80 /path/to/urllist")
  sys.exit(-1)

def main():
  if (len(sys.argv) != 3):
    usage()
  host = sys.argv[1]
  totaltime = 0
  maxtime = 0
  # start at infinity so that the first measurement becomes the minimum
  mintime = float("inf")
  lno = 0
  okresults = 0
  badresults = 0
  with open(sys.argv[2], 'r') as urllist:
    for urlline in urllist:
      testurl = urlline.rstrip()
      # skip comments and blank lines
      if (not testurl or testurl.startswith("#")):
        continue
      lno = lno + 1
      print("Testing (" + str(lno) + "): " + testurl)
      # perf_counter() measures wall-clock time, which is what matters
      # for a network round trip
      start = time.perf_counter()
      conn = http.client.HTTPConnection(host)
      conn.request("GET", testurl)
      resp = conn.getresponse()
      status = resp.status
      if (status == 200):
        okresults = okresults + 1
      else:
        badresults = badresults + 1
        print("Error:", status, resp.reason, str(lno))
      # read the whole body so the download counts toward the elapsed time
      data = resp.read()
      conn.close()
      stop = time.perf_counter()
      elapsed = stop - start
      if (elapsed < mintime):
        mintime = elapsed
      if (elapsed > maxtime):
        maxtime = elapsed
      totaltime = totaltime + elapsed
  # avoid dividing by zero when the list had no usable URLs
  if (lno == 0):
    print("No URLs found in " + sys.argv[2])
    return
  print("quality results, Ok=" + str(okresults) + ", Bad=" + str(badresults) + ", Total=" + str(lno))
  print("timing results: min(s)=" + str(mintime) + ", max(s)=" + str(maxtime) + ", avg(s)=" + str((totaltime / lno)))

if __name__ == "__main__":
  main()
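The min/max/average bookkeeping can also be kept apart from the HTTP calls, which makes it easy to check without a server. A minimal sketch, assuming the elapsed times are collected into a list first; summarize_timings is a hypothetical helper name, not something the script above defines.

```python
def summarize_timings(elapsed_times):
    """Return count, min, max and average (in seconds) for a list of timings.

    Returns None for an empty list so the caller can report "no data"
    instead of dividing by zero.
    """
    if not elapsed_times:
        return None
    return {
        "count": len(elapsed_times),
        "min": min(elapsed_times),
        "max": max(elapsed_times),
        "avg": sum(elapsed_times) / len(elapsed_times),
    }

summary = summarize_timings([0.5, 1.5, 1.0])
print(summary)
```

Using the built-in min() and max() on the collected list sidesteps having to pick a starting value for the running minimum.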

Once the script completes, you can stop the profiling. I was able to generate three reports from it - Execution Statistics, Memory Statistics and Coverage Statistics. Of these, I found the Execution statistics the most useful since it told me how many times a method was called, and what processing time on average was spent in each of these methods. Undoubtedly I will find more use for the other reports in the future, but for the moment I am happy to have profiling working under Eclipse.

Update Feb 16 2008

I was able to profile using Maven's Jetty plugin as well recently. Instead of putting the string "-XrunpiAgent:server=enabled" in JAVA_OPTS, we put it in MAVEN_OPTS, then run mvn -o jetty6:run. The RASERVER_HOME, LD_LIBRARY_PATH and PATH settings also need to be in there for the agent to work correctly. So my new improved jetty.sh now looks like this:

#!/bin/bash
BASE_MAVEN_OPTS="-Xmx2048m"
DEBUG_MAVEN_OPTS="-Xdebug -Xnoagent -Djava.compiler=NONE -Xrunjdwp:transport=dt_socket,address=8781,server=y,suspend=n"
PROFILE_MAVEN_OPTS="-XrunpiAgent:server=enabled"
case $1 in
  'debug')
    MAVEN_OPTS=$BASE_MAVEN_OPTS" "$DEBUG_MAVEN_OPTS
    ;;
  'profile')
    export RASERVER_HOME=/opt/tptpdc-4.1.0
    export PATH=$RASERVER_HOME/bin:$PATH
    export LD_LIBRARY_PATH=$RASERVER_HOME/lib:$LD_LIBRARY_PATH
    MAVEN_OPTS=$BASE_MAVEN_OPTS" "$PROFILE_MAVEN_OPTS
    ;;
  *)
    MAVEN_OPTS=$BASE_MAVEN_OPTS
    ;;
esac
export MAVEN_OPTS
mvn -o jetty6:run

To start a normal session, I just call jetty.sh. For debugging and profiling, I call jetty.sh debug and jetty.sh profile respectively. On the Eclipse side, I create a profile configuration in the same way as for Tomcat, by attaching the profiling client to the running Java application. The RAServer automatically discovers the Java app that is exposing profiling information.

Update Feb 27 2008

This post was republished by the folks at SYS-CON Media in their Open Web Developer's Journal and is available here. Goes to show that one should be careful about what one writes, it may end up anywhere :-). Thanks to Jeremy Geelan for making this happen.