View Javadoc
1   /*
2    * Copyright (C) 2007-2012 Argeo GmbH
3    *
4    * Licensed under the Apache License, Version 2.0 (the "License");
5    * you may not use this file except in compliance with the License.
6    * You may obtain a copy of the License at
7    *
8    *         http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS,
12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   * See the License for the specific language governing permissions and
14   * limitations under the License.
15   */
16  package org.argeo.slc.jsch;
17  
18  import java.io.IOException;
19  import java.io.InputStream;
20  import java.security.PrivilegedAction;
21  
22  import org.apache.commons.logging.Log;
23  import org.apache.commons.logging.LogFactory;
24  import org.argeo.slc.SlcException;
25  
26  import com.jcraft.jsch.JSch;
27  import com.jcraft.jsch.JSchException;
28  import com.jcraft.jsch.Logger;
29  import com.jcraft.jsch.Session;
30  import com.jcraft.jsch.UserAuthGSSAPIWithMIC;
31  
32  public abstract class AbstractJschTask implements Runnable {
33  	private final Log log = LogFactory.getLog(getClass());
34  
35  	private SshTarget sshTarget;
36  
37  	protected Session openSession() {
38  		if (sshTarget.getSession() != null) {
39  			Session session = sshTarget.getSession();
40  			if (session.isConnected()) {
41  				if (log.isTraceEnabled())
42  					log.debug("Using cached session to " + getSshTarget() + " via SSH");
43  				return session;
44  			}
45  		}
46  
47  		try {
48  			JSch jsch = new JSch();
49  			if (sshTarget.getUsePrivateKey() && sshTarget.getLocalPrivateKey().exists())
50  				jsch.addIdentity(sshTarget.getLocalPrivateKey().getAbsolutePath());
51  			Session session = jsch.getSession(getSshTarget().getUser(), getSshTarget().getHost(),
52  					getSshTarget().getPort());
53  
54  			session.setUserInfo(getSshTarget().getUserInfo());
55  			session.setConfig("userauth.gssapi-with-mic", UserAuthGSSAPIWithMIC.class.getName());
56  			session.setServerAliveInterval(1000);
57  			session.connect();
58  			if (log.isTraceEnabled())
59  				log.trace("Connected to " + getSshTarget() + " via SSH");
60  			if (sshTarget.getSession() != null) {
61  				if (log.isTraceEnabled())
62  					log.trace("The cached session to " + getSshTarget() + " was disconnected and was reset.");
63  				sshTarget.setSession(session);
64  			}
65  			return session;
66  		} catch (JSchException e) {
67  			if (sshTarget.getUserInfo() instanceof SimpleUserInfo)
68  				((SimpleUserInfo) sshTarget.getUserInfo()).reset();
69  			throw new SlcException("Could not open session to " + getSshTarget(), e);
70  		}
71  	}
72  
73  	public void run() {
74  		Session session = openSession();
75  		try {
76  			run(session);
77  		} finally {
78  			if (sshTarget != null && sshTarget.getSession() == null) {
79  				session.disconnect();
80  				if (log.isTraceEnabled())
81  					log.trace("Disconnected from " + getSshTarget() + " via SSH");
82  			}
83  		}
84  	}
85  
86  	abstract void run(Session session);
87  
88  	protected int checkAck(InputStream in) throws IOException {
89  		int b = in.read();
90  		// b may be 0 for success,
91  		// 1 for error,
92  		// 2 for fatal error,
93  		// -1
94  		if (b == 0)
95  			return b;
96  		else if (b == -1)
97  			return b;// throw new SlcException("SSH ack returned -1");
98  		else if (b == 1 || b == 2) {
99  			StringBuffer sb = new StringBuffer();
100 			int c;
101 			do {
102 				c = in.read();
103 				sb.append((char) c);
104 			} while (c != '\n');
105 			if (b == 1) { // error
106 				throw new SlcException("SSH ack error: " + sb.toString());
107 			}
108 			if (b == 2) { // fatal error
109 				throw new SlcException("SSH fatal error: " + sb.toString());
110 			}
111 		}
112 		return b;
113 	}
114 
115 	public SshTarget getSshTarget() {
116 		if (sshTarget == null)
117 			throw new SlcException("No SSH target defined.");
118 		return sshTarget;
119 	}
120 
121 	public void setSshTarget(SshTarget sshTarget) {
122 		this.sshTarget = sshTarget;
123 	}
124 
125 	PrivilegedAction<Void> asPrivilegedAction() {
126 		return new PrivilegedAction<Void>() {
127 			public Void run() {
128 				AbstractJschTask.this.run();
129 				return null;
130 			}
131 		};
132 	}
133 
134 	static {
135 		JSch.setLogger(new JschLogger());
136 	}
137 
138 	private static class JschLogger implements Logger {
139 		private final Log log = LogFactory.getLog(JschLogger.class);
140 
141 		// TODO better support levels
142 		@Override
143 		public boolean isEnabled(int level) {
144 			if (log.isTraceEnabled())
145 				return true;
146 			return false;
147 		}
148 
149 		@Override
150 		public void log(int level, String message) {
151 			log.trace(message);
152 		}
153 
154 	}
155 }